SDXS:知识蒸馏在高效图像生成中的应用

2024-08-22 21:44

本文主要是介绍SDXS:知识蒸馏在高效图像生成中的应用,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

人工智能咨询培训老师叶梓 转载标明出处

扩散模型虽然在图像生成方面表现出色,但其迭代采样过程导致在低功耗设备上部署面临挑战,同时在云端高性能GPU平台上的能耗也不容忽视。为了解决这一问题,小米公司的Yuda Song、Zehao Sun、Xuanwu Yin等人提出了一种新的方法——SDXS,通过知识蒸馏简化了U-Net和图像解码器架构,并引入了一种创新的一步式DM训练技术,使用特征匹配和得分蒸馏,从而在单GPU上实现了大约100 FPS(比SD v1.5快30倍)和30 FPS(比SDXL快60倍)的推理速度。

图1为在图像生成时间限制为1秒的情况下,不同模型的性能对比。SDXL模型在这种情况下只能使用16次函数评估(NFEs)来生成稍微模糊的图像,而提出的SDXS-1024模型却能够生成30张清晰的图像。这表明SDXS-1024在保持图像质量的同时显著提高了生成速度。本方法还能够训练ControlNet,这是一种能够嵌入空间引导的网络,用于图像到图像的任务,如草图到图像的转换、修复和超分辨率等。证明了SDXS方法的灵活性和应用潜力。

方法

LDM框架由三个关键要素组成:文本编码器、图像解码器以及一个需要多次迭代以生成清晰图像的去噪模型。由于文本编码器的开销相对较低,因此优化其大小并不是研究的重点。

VAE优化:LDM框架通过将样本投影到计算效率更高的低维潜在空间,显著提高了高分辨率图像扩散模型的训练效率。这一过程通过使用预训练模型,如变分自编码器(Variational AutoEncoder, VAE)或向量量化变分自编码器(Vector Quantised-Variational AutoEncoder, VQVAE)来实现高比例图像压缩。VAE包含一个将图像映射到潜在空间的编码器,以及一个重建图像的解码器。其训练通过平衡重建损失、Kullback-Leibler (KL) 散度和GAN损失来优化。然而,训练中对所有样本同等对待引入了冗余。研究者们提出了一种VAE蒸馏(VD)损失,用于训练一个小型的图像解码器G: 其中,D是GAN判别器,用于平衡两个损失项,表示在8倍下采样图像上的L1损失。图2(a)展示了蒸馏小型图像解码器的训练策略。倡使用简化的CNN架构,不包含注意力机制和归一化层等复杂组件,只关注基本的残差块和上采样层。

U-Net优化: LDMs采用U-Net架构作为核心去噪模型,该架构结合了残差块和Transformer块。为了利用预训练的U-Nets的能力,同时减少计算需求和参数数量,研究者们采用了知识蒸馏策略,这一策略受到BK-SDM的块移除训练策略启发。这涉及从U-Net中选择性地移除残差和Transformer块,目的是训练一个更紧凑的模型,该模型仍能有效复现原始模型的中间特征图和输出。图2(b)展示了蒸馏小型U-Net的训练策略。知识蒸馏通过输出知识蒸馏(OKD)和特征知识蒸馏(FKD)损失实现:总的损失函数是两者的结合: 其中,λF​平衡两个损失项。与BK-SDM不同,研究者们排除了原始的去噪损失。模型基于SD-2.1基础版和SDXL-1.0基础版进行了小型化。对于SD-2.1基础版,研究者们去除了中间阶段、下采样阶段的最后阶段和上采样阶段的第一阶段,并去除了最高分辨率阶段的Transformer块。对于SDXL-1.0基础版,研究者们去除了大部分Transformer块。

ControlNet优化: ControlNet通过嵌入空间引导来增强扩散模型,使图像到图像的任务如草图到图像的转换、修复和超分辨率成为可能。它复制了U-Net的编码器架构和参数,并增加了额外的卷积层以纳入空间控制。尽管ControlNet继承了U-Net的参数并采用零卷积来提高训练稳定性,但其训练过程仍然成本高昂且显著受数据集质量影响。为了解决这些挑战,研究者们提出了一种蒸馏方法,将原始U-Net中的ControlNet蒸馏到小型U-Net中的相应ControlNet。图2(b)展示了这一过程,不是直接蒸馏ControlNet零卷积的输出,而是将ControlNet与U-Net结合,然后蒸馏U-Net的中间特征图和输出,这使得蒸馏后的ControlNet和小型U-Net能够更好地协同工作。考虑到ControlNet不影响U-Net编码器的特征图,特征蒸馏仅应用于U-Net的解码器。

尽管扩散模型(DMs)在图像生成方面表现出色,但它们依赖于多个采样步骤,即使采用先进的采样器,这也引入了显著的推理延迟。为了解决这个问题,先前的研究引入了知识蒸馏技术,例如渐进式蒸馏(progressive distillation)和一致性蒸馏(consistency distillation),旨在减少采样步骤并加速推理。然而,这些方法通常只能在4到8个采样步骤中产生清晰的图像,这与在生成对抗网络(GANs)中看到的一步式生成过程形成了鲜明对比。

直接训练一步式模型的方法包括初始化噪声ϵ,并使用常微分方程(ODE)采样器ψ进行采样以获得生成的图像,从而构建噪声-图像对。这些对在训练期间作为学生模型的输入和真实情况。然而,这种方法通常导致生成质量低下的图像。根本问题是使用预训练的DM生成的噪声-图像对的采样轨迹交叉,导致不适定问题。Rectified Flow通过拉直采样轨迹来解决这一挑战。它替换了训练目标,并提出了一种“重流”策略来优化配对,从而最小化轨迹交叉。

采样轨迹的交叉可能导致一个噪声输入对应多个真实图像,导致训练模型生成的图像是多个可行输出的加权和。为了解决这个问题,研究者们探索了改变权重方案以优先考虑更清晰图像的替代损失函数。在大多数情况下,可以使用L1损失、感知损失和LPIPS损失来改变权重形式。研究者们基于特征匹配的方法,计算由编码器模型生成的中间特征图上损失。具体来说,他们从DISTS损失中汲取灵感,对这些特征图应用结构相似性指数(SSIM)以获得更精细的特征匹配损失: 其中 是由编码器 编码的第 个中间特征图上计算的SSIM损失的权重,是由小型U-Net 生成的图像,是由原始U-Net xϕ​ 使用ODE采样器ψ生成的图像。在实践中,使用预训练的CNN骨干、ViT骨干和DM U-Net的编码器都能产生有利的结果,与MSE损失的比较在图6中展示。

尽管特征匹配损失可以产生几乎清晰的图像,但它未能实现真正的分布匹配,因此训练的模型只能作为正式训练的初始化。为了解决这一差距,Diff-Instruct中使用的训练策略,该策略旨在通过在时间步上匹配边际得分函数,使模型的输出分布与预训练模型的分布更紧密地对齐。然而,因为它需要在 t→T 时添加高水平的噪声以使目标得分可计算,此时估计的得分函数是不准确的。研究者们指出,扩散模型的采样轨迹从粗糙到精细,这意味着 t→T 时,得分函数提供了低频信息的梯度,而 t→0 时,它提供了高频信息的梯度。因此,研究者们将时间步分为两段:,后者被LFM替换,因为它可以提供足够的低频梯度。这种策略可以正式表示为: 其中 是在时间 t 和状态 下的函数,用于平衡两段的梯度,。研究者们有意将 α 设置接近1,并将 设置在高值,以确保模型的输出分布与预训练得分函数预测的分布平滑对齐。在概率密度显著重叠后,逐渐降低 α 和 。图3描述了训练策略,其中离线DM表示预训练DM的U-Net,在线DM是从离线DM初始化并在生成的图像上通过等式(1)微调得到的。在实践中,在线DM和学生DM交替训练,如算法1所示。

 一旦一步式DM训练完成,就可以像其他DM一样进行微调,以调整生成图像的风格。研究者们结合使用LoRA和提出的分段得分蒸馏来微调一步式DM,如图4所示。具体为将预训练的LoRA插入离线DM中,如果它也与教师DM兼容,也会插入到那里。要注意,不将LoRA插入在线DM中,因为它对应于一步式DM的输出分布。然后,使用与一步式训练相同的训练程序,但跳过特征匹配预热,因为LoRA微调比完全微调更稳定。另外当教师DM不能纳入预训练的LoRA时,使用降低的 。通过这种方式,可以将预训练的LoRA蒸馏到SDXS的LoRA中。

研究者们的方法也可以适应于ControlNet的训练,使微小的一步式模型能够在其图像生成过程中纳入图像条件,如图5所示。与用于文本到图像生成的基础模型相比,这里训练的模型是伴随前面提到的小型U-Net的蒸馏ControlNet,并且在训练期间U-Net的参数是固定的。重点是需要从教师模型采样的图像中提取控制图像,而不是从数据集图像中提取,以确保噪声、目标图像和控制图像形成一个配对三元组。此外,原始多步U-Net的伴随预训练ControlNet与在线U-Net和离线U-Net集成,但不参与训练。与文本编码器类似,其功能限于作为预训练的特征提取器。通过这种方式,为了进一步减少损失L,训练的ControlNet学习利用从目标图像中提取的控制图像。同时,得分蒸馏鼓励模型匹配边际分布,增强生成图像的上下文相关性。值得注意的是,研究发现用新初始化的噪声替换U-Net噪声输入的一部分可以增强控制能力。图5展示了基于特征匹配和得分蒸馏提出的一步式ControlNet训练策略。虚线表示梯度反向传播。

实验

研究者的代码是基于diffusers库开发的。由于他们无法访问SD v2.1基础版和SDXL的训练数据集,整个训练过程几乎是无数据的,完全依赖于公开可访问数据集中提供的提示。他们使用开源的预训练模型与这些提示结合,生成相应的图像。为了训练模型,他们将训练小批量大小配置在1,024到2,048之间。为了在现有硬件上适应这个批量大小,必要时他们有策略地实施了梯度累积。他们发现所提出训练策略导致模型生成的图像纹理较少。因此,在训练后,他们使用GAN损失结合极低秩的LoRA进行了短暂的微调。当需要GAN损失时,他们使用了StyleGAN-T中的Projected GAN损失,基本设置与ADD一致。对于SDXS-1024的训练,他们使用Vega,SDXL的紧凑版本,作为在线DM和离线DM的初始化,以减少训练开销。

表3为在MS-COCO 2017验证集上的定量结果,即FID和CLIP分数。由于FID对高斯分布的强烈假设,它不是衡量图像质量的一个好的指标,因为它受到生成样本多样性的显著影响。表3显示了MS-COCO 2017 5K子集上的性能比较,图7显示了一些示例。尽管模型大小和所需的采样步骤数量都有明显减少,但SDXS-512的提示跟随能力仍然优于SD v1.5。与Tiny SD(另一个为效率而设计的模型)相比,SDXS-512的优越性更加明显。这一观察结果也在SDXS-1024的性能中得到了一致的验证。使用所提方法训练LoRA的样本如图9所示。显然,模型生成的图像风格可以有效地转移到与离线DM集成的风格导向LoRA匹配的风格,同时通常保持场景布局的一致性。

研究者引入的一步式训练方法是足够通用的,可以应用于图像条件生成。他们展示了其在促进图像到图像转换方面的有效性,特别是利用ControlNet进行涉及canny边缘和深度图的转换。图8展示了两个不同任务的代表性示例,突出了生成图像紧密遵循控制图像提供的指导的能力。然而,这也揭示了在图像多样性方面的显著局限性。如图1所示,虽然问题可以通过替换提示来缓解,但它仍然是后续研究工作中加强的领域。

实验证明将高效的图像条件生成部署在边缘设备上是一个充满前景的研究方向,研究者计划在未来探索包括修复和超分辨率在内的更多应用。通过不断的技术创新和优化,人工智能在图像生成领域的应用将更加广泛和深入。

论文链接:https://arxiv.org/abs/2403.16627

项目地址:https://idkiro.github.io/sdxs/

这篇关于SDXS:知识蒸馏在高效图像生成中的应用的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1097483

相关文章

Java架构师知识体认识

源码分析 常用设计模式 Proxy代理模式Factory工厂模式Singleton单例模式Delegate委派模式Strategy策略模式Prototype原型模式Template模板模式 Spring5 beans 接口实例化代理Bean操作 Context Ioc容器设计原理及高级特性Aop设计原理Factorybean与Beanfactory Transaction 声明式事物

中文分词jieba库的使用与实景应用(一)

知识星球:https://articles.zsxq.com/id_fxvgc803qmr2.html 目录 一.定义: 精确模式(默认模式): 全模式: 搜索引擎模式: paddle 模式(基于深度学习的分词模式): 二 自定义词典 三.文本解析   调整词出现的频率 四. 关键词提取 A. 基于TF-IDF算法的关键词提取 B. 基于TextRank算法的关键词提取

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

水位雨量在线监测系统概述及应用介绍

在当今社会,随着科技的飞速发展,各种智能监测系统已成为保障公共安全、促进资源管理和环境保护的重要工具。其中,水位雨量在线监测系统作为自然灾害预警、水资源管理及水利工程运行的关键技术,其重要性不言而喻。 一、水位雨量在线监测系统的基本原理 水位雨量在线监测系统主要由数据采集单元、数据传输网络、数据处理中心及用户终端四大部分构成,形成了一个完整的闭环系统。 数据采集单元:这是系统的“眼睛”,

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

高效+灵活,万博智云全球发布AWS无代理跨云容灾方案!

摘要 近日,万博智云推出了基于AWS的无代理跨云容灾解决方案,并与拉丁美洲,中东,亚洲的合作伙伴面向全球开展了联合发布。这一方案以AWS应用环境为基础,将HyperBDR平台的高效、灵活和成本效益优势与无代理功能相结合,为全球企业带来实现了更便捷、经济的数据保护。 一、全球联合发布 9月2日,万博智云CEO Michael Wong在线上平台发布AWS无代理跨云容灾解决方案的阐述视频,介绍了

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

hdu1394(线段树点更新的应用)

题意:求一个序列经过一定的操作得到的序列的最小逆序数 这题会用到逆序数的一个性质,在0到n-1这些数字组成的乱序排列,将第一个数字A移到最后一位,得到的逆序数为res-a+(n-a-1) 知道上面的知识点后,可以用暴力来解 代码如下: #include<iostream>#include<algorithm>#include<cstring>#include<stack>#in

嵌入式QT开发:构建高效智能的嵌入式系统

摘要: 本文深入探讨了嵌入式 QT 相关的各个方面。从 QT 框架的基础架构和核心概念出发,详细阐述了其在嵌入式环境中的优势与特点。文中分析了嵌入式 QT 的开发环境搭建过程,包括交叉编译工具链的配置等关键步骤。进一步探讨了嵌入式 QT 的界面设计与开发,涵盖了从基本控件的使用到复杂界面布局的构建。同时也深入研究了信号与槽机制在嵌入式系统中的应用,以及嵌入式 QT 与硬件设备的交互,包括输入输出设

sqlite3 相关知识

WAL 模式 VS 回滚模式 特性WAL 模式回滚模式(Rollback Journal)定义使用写前日志来记录变更。使用回滚日志来记录事务的所有修改。特点更高的并发性和性能;支持多读者和单写者。支持安全的事务回滚,但并发性较低。性能写入性能更好,尤其是读多写少的场景。写操作会造成较大的性能开销,尤其是在事务开始时。写入流程数据首先写入 WAL 文件,然后才从 WAL 刷新到主数据库。数据在开始