SDXS:知识蒸馏在高效图像生成中的应用

2024-08-22 21:44

本文主要是介绍SDXS:知识蒸馏在高效图像生成中的应用,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

人工智能咨询培训老师叶梓 转载标明出处

扩散模型虽然在图像生成方面表现出色,但其迭代采样过程导致在低功耗设备上部署面临挑战,同时在云端高性能GPU平台上的能耗也不容忽视。为了解决这一问题,小米公司的Yuda Song、Zehao Sun、Xuanwu Yin等人提出了一种新的方法——SDXS,通过知识蒸馏简化了U-Net和图像解码器架构,并引入了一种创新的一步式DM训练技术,使用特征匹配和得分蒸馏,从而在单GPU上实现了大约100 FPS(比SD v1.5快30倍)和30 FPS(比SDXL快60倍)的推理速度。

图1为在图像生成时间限制为1秒的情况下,不同模型的性能对比。SDXL模型在这种情况下只能使用16次函数评估(NFEs)来生成稍微模糊的图像,而提出的SDXS-1024模型却能够生成30张清晰的图像。这表明SDXS-1024在保持图像质量的同时显著提高了生成速度。本方法还能够训练ControlNet,这是一种能够嵌入空间引导的网络,用于图像到图像的任务,如草图到图像的转换、修复和超分辨率等。证明了SDXS方法的灵活性和应用潜力。

方法

LDM框架由三个关键要素组成:文本编码器、图像解码器以及一个需要多次迭代以生成清晰图像的去噪模型。由于文本编码器的开销相对较低,因此优化其大小并不是研究的重点。

VAE优化:LDM框架通过将样本投影到计算效率更高的低维潜在空间,显著提高了高分辨率图像扩散模型的训练效率。这一过程通过使用预训练模型,如变分自编码器(Variational AutoEncoder, VAE)或向量量化变分自编码器(Vector Quantised-Variational AutoEncoder, VQVAE)来实现高比例图像压缩。VAE包含一个将图像映射到潜在空间的编码器,以及一个重建图像的解码器。其训练通过平衡重建损失、Kullback-Leibler (KL) 散度和GAN损失来优化。然而,训练中对所有样本同等对待引入了冗余。研究者们提出了一种VAE蒸馏(VD)损失,用于训练一个小型的图像解码器G: 其中,D是GAN判别器,用于平衡两个损失项,表示在8倍下采样图像上的L1损失。图2(a)展示了蒸馏小型图像解码器的训练策略。倡使用简化的CNN架构,不包含注意力机制和归一化层等复杂组件,只关注基本的残差块和上采样层。

U-Net优化: LDMs采用U-Net架构作为核心去噪模型,该架构结合了残差块和Transformer块。为了利用预训练的U-Nets的能力,同时减少计算需求和参数数量,研究者们采用了知识蒸馏策略,这一策略受到BK-SDM的块移除训练策略启发。这涉及从U-Net中选择性地移除残差和Transformer块,目的是训练一个更紧凑的模型,该模型仍能有效复现原始模型的中间特征图和输出。图2(b)展示了蒸馏小型U-Net的训练策略。知识蒸馏通过输出知识蒸馏(OKD)和特征知识蒸馏(FKD)损失实现:总的损失函数是两者的结合: 其中,λF​平衡两个损失项。与BK-SDM不同,研究者们排除了原始的去噪损失。模型基于SD-2.1基础版和SDXL-1.0基础版进行了小型化。对于SD-2.1基础版,研究者们去除了中间阶段、下采样阶段的最后阶段和上采样阶段的第一阶段,并去除了最高分辨率阶段的Transformer块。对于SDXL-1.0基础版,研究者们去除了大部分Transformer块。

ControlNet优化: ControlNet通过嵌入空间引导来增强扩散模型,使图像到图像的任务如草图到图像的转换、修复和超分辨率成为可能。它复制了U-Net的编码器架构和参数,并增加了额外的卷积层以纳入空间控制。尽管ControlNet继承了U-Net的参数并采用零卷积来提高训练稳定性,但其训练过程仍然成本高昂且显著受数据集质量影响。为了解决这些挑战,研究者们提出了一种蒸馏方法,将原始U-Net中的ControlNet蒸馏到小型U-Net中的相应ControlNet。图2(b)展示了这一过程,不是直接蒸馏ControlNet零卷积的输出,而是将ControlNet与U-Net结合,然后蒸馏U-Net的中间特征图和输出,这使得蒸馏后的ControlNet和小型U-Net能够更好地协同工作。考虑到ControlNet不影响U-Net编码器的特征图,特征蒸馏仅应用于U-Net的解码器。

尽管扩散模型(DMs)在图像生成方面表现出色,但它们依赖于多个采样步骤,即使采用先进的采样器,这也引入了显著的推理延迟。为了解决这个问题,先前的研究引入了知识蒸馏技术,例如渐进式蒸馏(progressive distillation)和一致性蒸馏(consistency distillation),旨在减少采样步骤并加速推理。然而,这些方法通常只能在4到8个采样步骤中产生清晰的图像,这与在生成对抗网络(GANs)中看到的一步式生成过程形成了鲜明对比。

直接训练一步式模型的方法包括初始化噪声ϵ,并使用常微分方程(ODE)采样器ψ进行采样以获得生成的图像,从而构建噪声-图像对。这些对在训练期间作为学生模型的输入和真实情况。然而,这种方法通常导致生成质量低下的图像。根本问题是使用预训练的DM生成的噪声-图像对的采样轨迹交叉,导致不适定问题。Rectified Flow通过拉直采样轨迹来解决这一挑战。它替换了训练目标,并提出了一种“重流”策略来优化配对,从而最小化轨迹交叉。

采样轨迹的交叉可能导致一个噪声输入对应多个真实图像,导致训练模型生成的图像是多个可行输出的加权和。为了解决这个问题,研究者们探索了改变权重方案以优先考虑更清晰图像的替代损失函数。在大多数情况下,可以使用L1损失、感知损失和LPIPS损失来改变权重形式。研究者们基于特征匹配的方法,计算由编码器模型生成的中间特征图上损失。具体来说,他们从DISTS损失中汲取灵感,对这些特征图应用结构相似性指数(SSIM)以获得更精细的特征匹配损失: 其中 是由编码器 编码的第 个中间特征图上计算的SSIM损失的权重,是由小型U-Net 生成的图像,是由原始U-Net xϕ​ 使用ODE采样器ψ生成的图像。在实践中,使用预训练的CNN骨干、ViT骨干和DM U-Net的编码器都能产生有利的结果,与MSE损失的比较在图6中展示。

尽管特征匹配损失可以产生几乎清晰的图像,但它未能实现真正的分布匹配,因此训练的模型只能作为正式训练的初始化。为了解决这一差距,Diff-Instruct中使用的训练策略,该策略旨在通过在时间步上匹配边际得分函数,使模型的输出分布与预训练模型的分布更紧密地对齐。然而,因为它需要在 t→T 时添加高水平的噪声以使目标得分可计算,此时估计的得分函数是不准确的。研究者们指出,扩散模型的采样轨迹从粗糙到精细,这意味着 t→T 时,得分函数提供了低频信息的梯度,而 t→0 时,它提供了高频信息的梯度。因此,研究者们将时间步分为两段:,后者被LFM替换,因为它可以提供足够的低频梯度。这种策略可以正式表示为: 其中 是在时间 t 和状态 下的函数,用于平衡两段的梯度,。研究者们有意将 α 设置接近1,并将 设置在高值,以确保模型的输出分布与预训练得分函数预测的分布平滑对齐。在概率密度显著重叠后,逐渐降低 α 和 。图3描述了训练策略,其中离线DM表示预训练DM的U-Net,在线DM是从离线DM初始化并在生成的图像上通过等式(1)微调得到的。在实践中,在线DM和学生DM交替训练,如算法1所示。

 一旦一步式DM训练完成,就可以像其他DM一样进行微调,以调整生成图像的风格。研究者们结合使用LoRA和提出的分段得分蒸馏来微调一步式DM,如图4所示。具体为将预训练的LoRA插入离线DM中,如果它也与教师DM兼容,也会插入到那里。要注意,不将LoRA插入在线DM中,因为它对应于一步式DM的输出分布。然后,使用与一步式训练相同的训练程序,但跳过特征匹配预热,因为LoRA微调比完全微调更稳定。另外当教师DM不能纳入预训练的LoRA时,使用降低的 。通过这种方式,可以将预训练的LoRA蒸馏到SDXS的LoRA中。

研究者们的方法也可以适应于ControlNet的训练,使微小的一步式模型能够在其图像生成过程中纳入图像条件,如图5所示。与用于文本到图像生成的基础模型相比,这里训练的模型是伴随前面提到的小型U-Net的蒸馏ControlNet,并且在训练期间U-Net的参数是固定的。重点是需要从教师模型采样的图像中提取控制图像,而不是从数据集图像中提取,以确保噪声、目标图像和控制图像形成一个配对三元组。此外,原始多步U-Net的伴随预训练ControlNet与在线U-Net和离线U-Net集成,但不参与训练。与文本编码器类似,其功能限于作为预训练的特征提取器。通过这种方式,为了进一步减少损失L,训练的ControlNet学习利用从目标图像中提取的控制图像。同时,得分蒸馏鼓励模型匹配边际分布,增强生成图像的上下文相关性。值得注意的是,研究发现用新初始化的噪声替换U-Net噪声输入的一部分可以增强控制能力。图5展示了基于特征匹配和得分蒸馏提出的一步式ControlNet训练策略。虚线表示梯度反向传播。

实验

研究者的代码是基于diffusers库开发的。由于他们无法访问SD v2.1基础版和SDXL的训练数据集,整个训练过程几乎是无数据的,完全依赖于公开可访问数据集中提供的提示。他们使用开源的预训练模型与这些提示结合,生成相应的图像。为了训练模型,他们将训练小批量大小配置在1,024到2,048之间。为了在现有硬件上适应这个批量大小,必要时他们有策略地实施了梯度累积。他们发现所提出训练策略导致模型生成的图像纹理较少。因此,在训练后,他们使用GAN损失结合极低秩的LoRA进行了短暂的微调。当需要GAN损失时,他们使用了StyleGAN-T中的Projected GAN损失,基本设置与ADD一致。对于SDXS-1024的训练,他们使用Vega,SDXL的紧凑版本,作为在线DM和离线DM的初始化,以减少训练开销。

表3为在MS-COCO 2017验证集上的定量结果,即FID和CLIP分数。由于FID对高斯分布的强烈假设,它不是衡量图像质量的一个好的指标,因为它受到生成样本多样性的显著影响。表3显示了MS-COCO 2017 5K子集上的性能比较,图7显示了一些示例。尽管模型大小和所需的采样步骤数量都有明显减少,但SDXS-512的提示跟随能力仍然优于SD v1.5。与Tiny SD(另一个为效率而设计的模型)相比,SDXS-512的优越性更加明显。这一观察结果也在SDXS-1024的性能中得到了一致的验证。使用所提方法训练LoRA的样本如图9所示。显然,模型生成的图像风格可以有效地转移到与离线DM集成的风格导向LoRA匹配的风格,同时通常保持场景布局的一致性。

研究者引入的一步式训练方法是足够通用的,可以应用于图像条件生成。他们展示了其在促进图像到图像转换方面的有效性,特别是利用ControlNet进行涉及canny边缘和深度图的转换。图8展示了两个不同任务的代表性示例,突出了生成图像紧密遵循控制图像提供的指导的能力。然而,这也揭示了在图像多样性方面的显著局限性。如图1所示,虽然问题可以通过替换提示来缓解,但它仍然是后续研究工作中加强的领域。

实验证明将高效的图像条件生成部署在边缘设备上是一个充满前景的研究方向,研究者计划在未来探索包括修复和超分辨率在内的更多应用。通过不断的技术创新和优化,人工智能在图像生成领域的应用将更加广泛和深入。

论文链接:https://arxiv.org/abs/2403.16627

项目地址:https://idkiro.github.io/sdxs/

这篇关于SDXS:知识蒸馏在高效图像生成中的应用的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1097483

相关文章

Spring Boot + MyBatis Plus 高效开发实战从入门到进阶优化(推荐)

《SpringBoot+MyBatisPlus高效开发实战从入门到进阶优化(推荐)》本文将详细介绍SpringBoot+MyBatisPlus的完整开发流程,并深入剖析分页查询、批量操作、动... 目录Spring Boot + MyBATis Plus 高效开发实战:从入门到进阶优化1. MyBatis

Python中随机休眠技术原理与应用详解

《Python中随机休眠技术原理与应用详解》在编程中,让程序暂停执行特定时间是常见需求,当需要引入不确定性时,随机休眠就成为关键技巧,下面我们就来看看Python中随机休眠技术的具体实现与应用吧... 目录引言一、实现原理与基础方法1.1 核心函数解析1.2 基础实现模板1.3 整数版实现二、典型应用场景2

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为

在java中如何将inputStream对象转换为File对象(不生成本地文件)

《在java中如何将inputStream对象转换为File对象(不生成本地文件)》:本文主要介绍在java中如何将inputStream对象转换为File对象(不生成本地文件),具有很好的参考价... 目录需求说明问题解决总结需求说明在后端中通过POI生成Excel文件流,将输出流(outputStre

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Android Kotlin 高阶函数详解及其在协程中的应用小结

《AndroidKotlin高阶函数详解及其在协程中的应用小结》高阶函数是Kotlin中的一个重要特性,它能够将函数作为一等公民(First-ClassCitizen),使得代码更加简洁、灵活和可... 目录1. 引言2. 什么是高阶函数?3. 高阶函数的基础用法3.1 传递函数作为参数3.2 Lambda

Java中&和&&以及|和||的区别、应用场景和代码示例

《Java中&和&&以及|和||的区别、应用场景和代码示例》:本文主要介绍Java中的逻辑运算符&、&&、|和||的区别,包括它们在布尔和整数类型上的应用,文中通过代码介绍的非常详细,需要的朋友可... 目录前言1. & 和 &&代码示例2. | 和 ||代码示例3. 为什么要使用 & 和 | 而不是总是使

Python循环缓冲区的应用详解

《Python循环缓冲区的应用详解》循环缓冲区是一个线性缓冲区,逻辑上被视为一个循环的结构,本文主要为大家介绍了Python中循环缓冲区的相关应用,有兴趣的小伙伴可以了解一下... 目录什么是循环缓冲区循环缓冲区的结构python中的循环缓冲区实现运行循环缓冲区循环缓冲区的优势应用案例Python中的实现库

使用Python高效获取网络数据的操作指南

《使用Python高效获取网络数据的操作指南》网络爬虫是一种自动化程序,用于访问和提取网站上的数据,Python是进行网络爬虫开发的理想语言,拥有丰富的库和工具,使得编写和维护爬虫变得简单高效,本文将... 目录网络爬虫的基本概念常用库介绍安装库Requests和BeautifulSoup爬虫开发发送请求解