SDXS:知识蒸馏在高效图像生成中的应用

2024-08-22 21:44

本文主要是介绍SDXS:知识蒸馏在高效图像生成中的应用,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

人工智能咨询培训老师叶梓 转载标明出处

扩散模型虽然在图像生成方面表现出色,但其迭代采样过程导致在低功耗设备上部署面临挑战,同时在云端高性能GPU平台上的能耗也不容忽视。为了解决这一问题,小米公司的Yuda Song、Zehao Sun、Xuanwu Yin等人提出了一种新的方法——SDXS,通过知识蒸馏简化了U-Net和图像解码器架构,并引入了一种创新的一步式DM训练技术,使用特征匹配和得分蒸馏,从而在单GPU上实现了大约100 FPS(比SD v1.5快30倍)和30 FPS(比SDXL快60倍)的推理速度。

图1为在图像生成时间限制为1秒的情况下,不同模型的性能对比。SDXL模型在这种情况下只能使用16次函数评估(NFEs)来生成稍微模糊的图像,而提出的SDXS-1024模型却能够生成30张清晰的图像。这表明SDXS-1024在保持图像质量的同时显著提高了生成速度。本方法还能够训练ControlNet,这是一种能够嵌入空间引导的网络,用于图像到图像的任务,如草图到图像的转换、修复和超分辨率等。证明了SDXS方法的灵活性和应用潜力。

方法

LDM框架由三个关键要素组成:文本编码器、图像解码器以及一个需要多次迭代以生成清晰图像的去噪模型。由于文本编码器的开销相对较低,因此优化其大小并不是研究的重点。

VAE优化:LDM框架通过将样本投影到计算效率更高的低维潜在空间,显著提高了高分辨率图像扩散模型的训练效率。这一过程通过使用预训练模型,如变分自编码器(Variational AutoEncoder, VAE)或向量量化变分自编码器(Vector Quantised-Variational AutoEncoder, VQVAE)来实现高比例图像压缩。VAE包含一个将图像映射到潜在空间的编码器,以及一个重建图像的解码器。其训练通过平衡重建损失、Kullback-Leibler (KL) 散度和GAN损失来优化。然而,训练中对所有样本同等对待引入了冗余。研究者们提出了一种VAE蒸馏(VD)损失,用于训练一个小型的图像解码器G: 其中,D是GAN判别器,用于平衡两个损失项,表示在8倍下采样图像上的L1损失。图2(a)展示了蒸馏小型图像解码器的训练策略。倡使用简化的CNN架构,不包含注意力机制和归一化层等复杂组件,只关注基本的残差块和上采样层。

U-Net优化: LDMs采用U-Net架构作为核心去噪模型,该架构结合了残差块和Transformer块。为了利用预训练的U-Nets的能力,同时减少计算需求和参数数量,研究者们采用了知识蒸馏策略,这一策略受到BK-SDM的块移除训练策略启发。这涉及从U-Net中选择性地移除残差和Transformer块,目的是训练一个更紧凑的模型,该模型仍能有效复现原始模型的中间特征图和输出。图2(b)展示了蒸馏小型U-Net的训练策略。知识蒸馏通过输出知识蒸馏(OKD)和特征知识蒸馏(FKD)损失实现:总的损失函数是两者的结合: 其中,λF​平衡两个损失项。与BK-SDM不同,研究者们排除了原始的去噪损失。模型基于SD-2.1基础版和SDXL-1.0基础版进行了小型化。对于SD-2.1基础版,研究者们去除了中间阶段、下采样阶段的最后阶段和上采样阶段的第一阶段,并去除了最高分辨率阶段的Transformer块。对于SDXL-1.0基础版,研究者们去除了大部分Transformer块。

ControlNet优化: ControlNet通过嵌入空间引导来增强扩散模型,使图像到图像的任务如草图到图像的转换、修复和超分辨率成为可能。它复制了U-Net的编码器架构和参数,并增加了额外的卷积层以纳入空间控制。尽管ControlNet继承了U-Net的参数并采用零卷积来提高训练稳定性,但其训练过程仍然成本高昂且显著受数据集质量影响。为了解决这些挑战,研究者们提出了一种蒸馏方法,将原始U-Net中的ControlNet蒸馏到小型U-Net中的相应ControlNet。图2(b)展示了这一过程,不是直接蒸馏ControlNet零卷积的输出,而是将ControlNet与U-Net结合,然后蒸馏U-Net的中间特征图和输出,这使得蒸馏后的ControlNet和小型U-Net能够更好地协同工作。考虑到ControlNet不影响U-Net编码器的特征图,特征蒸馏仅应用于U-Net的解码器。

尽管扩散模型(DMs)在图像生成方面表现出色,但它们依赖于多个采样步骤,即使采用先进的采样器,这也引入了显著的推理延迟。为了解决这个问题,先前的研究引入了知识蒸馏技术,例如渐进式蒸馏(progressive distillation)和一致性蒸馏(consistency distillation),旨在减少采样步骤并加速推理。然而,这些方法通常只能在4到8个采样步骤中产生清晰的图像,这与在生成对抗网络(GANs)中看到的一步式生成过程形成了鲜明对比。

直接训练一步式模型的方法包括初始化噪声ϵ,并使用常微分方程(ODE)采样器ψ进行采样以获得生成的图像,从而构建噪声-图像对。这些对在训练期间作为学生模型的输入和真实情况。然而,这种方法通常导致生成质量低下的图像。根本问题是使用预训练的DM生成的噪声-图像对的采样轨迹交叉,导致不适定问题。Rectified Flow通过拉直采样轨迹来解决这一挑战。它替换了训练目标,并提出了一种“重流”策略来优化配对,从而最小化轨迹交叉。

采样轨迹的交叉可能导致一个噪声输入对应多个真实图像,导致训练模型生成的图像是多个可行输出的加权和。为了解决这个问题,研究者们探索了改变权重方案以优先考虑更清晰图像的替代损失函数。在大多数情况下,可以使用L1损失、感知损失和LPIPS损失来改变权重形式。研究者们基于特征匹配的方法,计算由编码器模型生成的中间特征图上损失。具体来说,他们从DISTS损失中汲取灵感,对这些特征图应用结构相似性指数(SSIM)以获得更精细的特征匹配损失: 其中 是由编码器 编码的第 个中间特征图上计算的SSIM损失的权重,是由小型U-Net 生成的图像,是由原始U-Net xϕ​ 使用ODE采样器ψ生成的图像。在实践中,使用预训练的CNN骨干、ViT骨干和DM U-Net的编码器都能产生有利的结果,与MSE损失的比较在图6中展示。

尽管特征匹配损失可以产生几乎清晰的图像,但它未能实现真正的分布匹配,因此训练的模型只能作为正式训练的初始化。为了解决这一差距,Diff-Instruct中使用的训练策略,该策略旨在通过在时间步上匹配边际得分函数,使模型的输出分布与预训练模型的分布更紧密地对齐。然而,因为它需要在 t→T 时添加高水平的噪声以使目标得分可计算,此时估计的得分函数是不准确的。研究者们指出,扩散模型的采样轨迹从粗糙到精细,这意味着 t→T 时,得分函数提供了低频信息的梯度,而 t→0 时,它提供了高频信息的梯度。因此,研究者们将时间步分为两段:,后者被LFM替换,因为它可以提供足够的低频梯度。这种策略可以正式表示为: 其中 是在时间 t 和状态 下的函数,用于平衡两段的梯度,。研究者们有意将 α 设置接近1,并将 设置在高值,以确保模型的输出分布与预训练得分函数预测的分布平滑对齐。在概率密度显著重叠后,逐渐降低 α 和 。图3描述了训练策略,其中离线DM表示预训练DM的U-Net,在线DM是从离线DM初始化并在生成的图像上通过等式(1)微调得到的。在实践中,在线DM和学生DM交替训练,如算法1所示。

 一旦一步式DM训练完成,就可以像其他DM一样进行微调,以调整生成图像的风格。研究者们结合使用LoRA和提出的分段得分蒸馏来微调一步式DM,如图4所示。具体为将预训练的LoRA插入离线DM中,如果它也与教师DM兼容,也会插入到那里。要注意,不将LoRA插入在线DM中,因为它对应于一步式DM的输出分布。然后,使用与一步式训练相同的训练程序,但跳过特征匹配预热,因为LoRA微调比完全微调更稳定。另外当教师DM不能纳入预训练的LoRA时,使用降低的 。通过这种方式,可以将预训练的LoRA蒸馏到SDXS的LoRA中。

研究者们的方法也可以适应于ControlNet的训练,使微小的一步式模型能够在其图像生成过程中纳入图像条件,如图5所示。与用于文本到图像生成的基础模型相比,这里训练的模型是伴随前面提到的小型U-Net的蒸馏ControlNet,并且在训练期间U-Net的参数是固定的。重点是需要从教师模型采样的图像中提取控制图像,而不是从数据集图像中提取,以确保噪声、目标图像和控制图像形成一个配对三元组。此外,原始多步U-Net的伴随预训练ControlNet与在线U-Net和离线U-Net集成,但不参与训练。与文本编码器类似,其功能限于作为预训练的特征提取器。通过这种方式,为了进一步减少损失L,训练的ControlNet学习利用从目标图像中提取的控制图像。同时,得分蒸馏鼓励模型匹配边际分布,增强生成图像的上下文相关性。值得注意的是,研究发现用新初始化的噪声替换U-Net噪声输入的一部分可以增强控制能力。图5展示了基于特征匹配和得分蒸馏提出的一步式ControlNet训练策略。虚线表示梯度反向传播。

实验

研究者的代码是基于diffusers库开发的。由于他们无法访问SD v2.1基础版和SDXL的训练数据集,整个训练过程几乎是无数据的,完全依赖于公开可访问数据集中提供的提示。他们使用开源的预训练模型与这些提示结合,生成相应的图像。为了训练模型,他们将训练小批量大小配置在1,024到2,048之间。为了在现有硬件上适应这个批量大小,必要时他们有策略地实施了梯度累积。他们发现所提出训练策略导致模型生成的图像纹理较少。因此,在训练后,他们使用GAN损失结合极低秩的LoRA进行了短暂的微调。当需要GAN损失时,他们使用了StyleGAN-T中的Projected GAN损失,基本设置与ADD一致。对于SDXS-1024的训练,他们使用Vega,SDXL的紧凑版本,作为在线DM和离线DM的初始化,以减少训练开销。

表3为在MS-COCO 2017验证集上的定量结果,即FID和CLIP分数。由于FID对高斯分布的强烈假设,它不是衡量图像质量的一个好的指标,因为它受到生成样本多样性的显著影响。表3显示了MS-COCO 2017 5K子集上的性能比较,图7显示了一些示例。尽管模型大小和所需的采样步骤数量都有明显减少,但SDXS-512的提示跟随能力仍然优于SD v1.5。与Tiny SD(另一个为效率而设计的模型)相比,SDXS-512的优越性更加明显。这一观察结果也在SDXS-1024的性能中得到了一致的验证。使用所提方法训练LoRA的样本如图9所示。显然,模型生成的图像风格可以有效地转移到与离线DM集成的风格导向LoRA匹配的风格,同时通常保持场景布局的一致性。

研究者引入的一步式训练方法是足够通用的,可以应用于图像条件生成。他们展示了其在促进图像到图像转换方面的有效性,特别是利用ControlNet进行涉及canny边缘和深度图的转换。图8展示了两个不同任务的代表性示例,突出了生成图像紧密遵循控制图像提供的指导的能力。然而,这也揭示了在图像多样性方面的显著局限性。如图1所示,虽然问题可以通过替换提示来缓解,但它仍然是后续研究工作中加强的领域。

实验证明将高效的图像条件生成部署在边缘设备上是一个充满前景的研究方向,研究者计划在未来探索包括修复和超分辨率在内的更多应用。通过不断的技术创新和优化,人工智能在图像生成领域的应用将更加广泛和深入。

论文链接:https://arxiv.org/abs/2403.16627

项目地址:https://idkiro.github.io/sdxs/

这篇关于SDXS:知识蒸馏在高效图像生成中的应用的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1097483

相关文章

一文详解Java异常处理你都了解哪些知识

《一文详解Java异常处理你都了解哪些知识》:本文主要介绍Java异常处理的相关资料,包括异常的分类、捕获和处理异常的语法、常见的异常类型以及自定义异常的实现,文中通过代码介绍的非常详细,需要的朋... 目录前言一、什么是异常二、异常的分类2.1 受检异常2.2 非受检异常三、异常处理的语法3.1 try-

MySQL重复数据处理的七种高效方法

《MySQL重复数据处理的七种高效方法》你是不是也曾遇到过这样的烦恼:明明系统测试时一切正常,上线后却频频出现重复数据,大批量导数据时,总有那么几条不听话的记录导致整个事务莫名回滚,今天,我就跟大家分... 目录1. 重复数据插入问题分析1.1 问题本质1.2 常见场景图2. 基础解决方案:使用异常捕获3.

Java中的Lambda表达式及其应用小结

《Java中的Lambda表达式及其应用小结》Java中的Lambda表达式是一项极具创新性的特性,它使得Java代码更加简洁和高效,尤其是在集合操作和并行处理方面,:本文主要介绍Java中的La... 目录前言1. 什么是Lambda表达式?2. Lambda表达式的基本语法例子1:最简单的Lambda表

使用Python实现图像LBP特征提取的操作方法

《使用Python实现图像LBP特征提取的操作方法》LBP特征叫做局部二值模式,常用于纹理特征提取,并在纹理分类中具有较强的区分能力,本文给大家介绍了如何使用Python实现图像LBP特征提取的操作方... 目录一、LBP特征介绍二、LBP特征描述三、一些改进版本的LBP1.圆形LBP算子2.旋转不变的LB

Python结合PyWebView库打造跨平台桌面应用

《Python结合PyWebView库打造跨平台桌面应用》随着Web技术的发展,将HTML/CSS/JavaScript与Python结合构建桌面应用成为可能,本文将系统讲解如何使用PyWebView... 目录一、技术原理与优势分析1.1 架构原理1.2 核心优势二、开发环境搭建2.1 安装依赖2.2 验

Java字符串操作技巧之语法、示例与应用场景分析

《Java字符串操作技巧之语法、示例与应用场景分析》在Java算法题和日常开发中,字符串处理是必备的核心技能,本文全面梳理Java中字符串的常用操作语法,结合代码示例、应用场景和避坑指南,可快速掌握字... 目录引言1. 基础操作1.1 创建字符串1.2 获取长度1.3 访问字符2. 字符串处理2.1 子字

IDEA自动生成注释模板的配置教程

《IDEA自动生成注释模板的配置教程》本文介绍了如何在IntelliJIDEA中配置类和方法的注释模板,包括自动生成项目名称、包名、日期和时间等内容,以及如何定制参数和返回值的注释格式,需要的朋友可以... 目录项目场景配置方法类注释模板定义类开头的注释步骤类注释效果方法注释模板定义方法开头的注释步骤方法注

Python如何自动生成环境依赖包requirements

《Python如何自动生成环境依赖包requirements》:本文主要介绍Python如何自动生成环境依赖包requirements问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑... 目录生成当前 python 环境 安装的所有依赖包1、命令2、常见问题只生成当前 项目 的所有依赖包1、

SpringShell命令行之交互式Shell应用开发方式

《SpringShell命令行之交互式Shell应用开发方式》本文将深入探讨SpringShell的核心特性、实现方式及应用场景,帮助开发者掌握这一强大工具,具有很好的参考价值,希望对大家有所帮助,如... 目录引言一、Spring Shell概述二、创建命令类三、命令参数处理四、命令分组与帮助系统五、自定

SpringBoot应用中出现的Full GC问题的场景与解决

《SpringBoot应用中出现的FullGC问题的场景与解决》这篇文章主要为大家详细介绍了SpringBoot应用中出现的FullGC问题的场景与解决方法,文中的示例代码讲解详细,感兴趣的小伙伴可... 目录Full GC的原理与触发条件原理触发条件对Spring Boot应用的影响示例代码优化建议结论F