【论文笔记】Scalable Diffusion Models with State Space Backbone

2024-03-09 17:28

本文主要是介绍【论文笔记】Scalable Diffusion Models with State Space Backbone,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

原文链接:https://arxiv.org/abs/2402.05608

1. 引言

主干网络是扩散模型发展的关键方面,其中基于CNN的U-Net(下采样-跳跃连接-上采样)和基于Transformer的结构(使用自注意力替换采样块)是代表性的例子。

状态空间模型(SSM)在长序列建模方面有极大潜力。本文受Mamba启发,建立基于SSM的扩散模型,称为DiS。DiS将所有输入(时间、条件和有噪声的图像patch)视为离散token。DiS中的状态空间模型使其比CNN和Transformer有更优的放缩性,且有更低的计算开销。

2. 方法

2.1 准备知识

扩散模型:扩散模型逐步向数据加入噪声,然后将此过程反过来从噪声生成数据。噪声的加入过程称为前向过程,可表达为马尔科夫链。逆过程中,使用高斯模型近似真实逆转移,其中学习相当于对噪声的预测(即使用噪声预测网络,来最小化噪声预测目标)。

条件扩散模型会将条件(如类别、文本等,通常形式为索引或连续嵌入)引入噪声预测目标中。

具体公式见扩散模型(Diffusion Model)简介 - CSDN。

状态空间主干:状态空间模型的传统定义是将 x ( t ) ∈ R N x(t)\in\mathbb R^N x(t)RN通过隐状态 h ( t ) ∈ R N h(t)\in\mathbb R^N h(t)RN映射为 y ( t ) ∈ R N y(t)\in\mathbb R^N y(t)RN的线性时不变系统:
h ′ ( t ) = A h ( t ) + B x ( t ) y ( t ) = C h ( t ) h'(t)=Ah(t)+Bx(t)\\y(t)=Ch(t) h(t)=Ah(t)+Bx(t)y(t)=Ch(t)

其中 A ∈ R N × N A\in\mathbb R^{N\times N} ARN×N为状态矩阵, B , C ∈ R N B,C\in\mathbb R^N B,CRN为输入和输出矩阵。真实世界的数据通常为离散形式,可将上式离散化为
h t = A ˉ h t − 1 + B ˉ x t y t = C h t h_t=\bar Ah_{t-1}+\bar Bx_t\\y_t=Ch_t ht=Aˉht1+Bˉxtyt=Cht

其中 A ˉ = exp ⁡ ( Δ ⋅ A ) , B ˉ = ( Δ ⋅ A ) − 1 ( exp ⁡ ( Δ ⋅ A ) − I ) ⋅ ( Δ B ) \bar A=\exp(\Delta\cdot A),\bar B=(\Delta\cdot A)^{-1}(\exp(\Delta\cdot A)-I)\cdot(\Delta B) Aˉ=exp(ΔA),Bˉ=(ΔA)1(exp(ΔA)I)(ΔB)为离散状态参数, Δ \Delta Δ为离散步长。

虽然SSM理论上性质优良,但通常有高计算量和数值不稳定性。结构状态空间模型(S4)通过强制 A A A的形式来减轻这一问题,能达到比Transformer更高的性能;Mamba则进一步通过输入依赖的选择机制和更快的硬件感知算法改进之。

2.2 模型结构设计

DiS参数化噪声预测网络 ϵ θ ( x t , t , c ) \epsilon_\theta(x_t,t,c) ϵθ(xt,t,c),以时间 t t t、条件 c c c和噪声图像 x t x_t xt,预测向 x t x_t xt加入的噪声。DiS基于双向Mamba结构,如下图所示。
在这里插入图片描述
图像patch化:DiS的第一层将输入图像 I ∈ R H × W × C I\in\mathbb R^{H\times W\times C} IRH×W×C转化为拉直的2D patch X ∈ R J × ( p 2 ⋅ C ) X\in\mathbb R^{J\times (p^2\cdot C)} XRJ×(p2C)。然后,通过对每个patch进行线性嵌入,转化为含 J J J个token的、维度为 D D D的序列。为每个输入token使用可学习位置编码。 J = H × W p 2 J=\frac{H\times W}{p^2} J=p2H×W由patch大小 p p p决定。

SSM块:输入token会被一组SSM块处理。SSM块的输入还包括时间 t t t与条件 c c c。本文使用双向序列建模,即SSM块的前向过程包含了前向和反向两个方向的处理。

跳跃连接:本文将 L L L个SSM块分为前半和后半两部分,每部分 ⌊ L 2 ⌋ \lfloor\frac L2\rfloor 2L个。设 h s h a l l o w , h d e e p ∈ R J × D h_{shallow},h_{deep}\in\mathbb{R}^{J\times D} hshallow,hdeepRJ×D分别为跳跃连接分支和主分支的隐状态,则通过拼接和线性投影后再送入下一个SSM块,即 L i n e a r ( C o n c a t ( h s h a l l o w , h d e e p ) ) \mathtt{Linear}(\mathtt{Concat}(h_{shallow},h_{deep})) Linear(Concat(hshallow,hdeep))

线性解码器:需要将最后一个SSM块的隐状态解码为噪声预测和对角化协方差矩阵(与原始输入尺寸相同)。本文使用线性解码器,即LayerNorm+线性层,将每个token转化为 p 2 ⋅ C p^2\cdot C p2C的张量。最后,将解码的token重排为原始大小,得到预测噪声与协方差。

条件引入:本文在输入token的序列上增加时间 t t t与条件 c c c的向量嵌入作为额外token(类似ViT中的类别token),从而无需修改SSM块。在最后一个SSM块后,从序列移除条件token。此外,还用自适应归一化层替换标准归一化层,使模型从 c c c t t t嵌入向量的和中回归缩放和偏移参数。

2.3 计算分析

对序列 X ∈ R 1 × J × D X\in\mathbb R^{1\times J\times D} XR1×J×D和状态扩维默认设置 E = 2 E=2 E=2,自注意力与SSM的计算复杂度分别为 O ( S A ) = 4 J D 2 + 2 J 2 D O(SA)=4JD^2+2J^2D O(SA)=4JD2+2J2D O ( S S M ) = 3 J ( 2 D ) N + J ( 2 D ) N 2 O(SSM)=3J(2D)N+J(2D)N^2 O(SSM)=3J(2D)N+J(2D)N2

其中自注意力的计算是序列长度 J J J的二次方,而SSM则是线性关系。注意 N N N为固定参数。这说明DiS有较强的可放缩性。

3. 实验

3.1 实验设置

数据集:仅使用水平翻转数据增广。

实施细节:本文对DiS的权重使用指数移动平均方法。

3.2 模型分析

patch大小的影响:当模型大小一致时,减小patch大小(增加token数),性能会提高。这可能是扩散模型噪声预测任务的低级特性,导致需要小型patch,而不像更高级的分类任务。对高分辨率图像,使用小尺寸patch可能会引入高计算成本,可将图像转换为低维隐式表达,然后再使用DiS处理。

长跳跃的影响:比较拼接( L i n e a r ( C o n c a t ( h s h a l l o w , h d e e p ) ) \mathtt{Linear}(\mathtt{Concat}(h_{shallow},h_{deep})) Linear(Concat(hshallow,hdeep))
、求和( h s h a l l o w + h d e e p h_{shallow}+h_{deep} hshallow+hdeep)和无跳跃连接三种方式。实验表明,求和不会带来明显的性能提升,因为SSM自身可以通过线性方式保留一些浅层信息。而使用拼接和可学习的线性投影可以大幅增加性能。

条件组合:比较两种引入时间 t t t的方案:(1)将 t t t视为token,与图像patch一同处理;(2)将 t t t的嵌入整合到SSM块的层归一化中,类似U-Net中的自适应分组归一化,得到自适应层归一化: A d a L N ( h , s ) = y s L a y e r N o r m ( h ) + y b AdaLN(h,s)=y_s\mathtt{LayerNorm}(h)+y_b AdaLN(h,s)=ysLayerNorm(h)+yb,其中 h h h为SSM的隐状态, y s , y b y_s,y_b ys,yb为时间嵌入的线性投影。实验表明前者的性能优于后者。

缩放模型大小:增大模型深度(SSM块层数)和宽度(隐状态维度)均能提高性能。

3.3 主要结果

无条件图像生成:DiS与基于U-Net或Transformer的扩散模型有相当的性能,但参数量更少。

以类别为条件的图像生成:本文的方法可以超过其余方法的性能。

这篇关于【论文笔记】Scalable Diffusion Models with State Space Backbone的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/791384

相关文章

鸿蒙中@State的原理使用详解(HarmonyOS 5)

《鸿蒙中@State的原理使用详解(HarmonyOS5)》@State是HarmonyOSArkTS框架中用于管理组件状态的核心装饰器,其核心作用是实现数据驱动UI的响应式编程模式,本文给大家介绍... 目录一、@State在鸿蒙中是做什么的?二、@Spythontate的基本原理1. 依赖关系的收集2.

利用Python快速搭建Markdown笔记发布系统

《利用Python快速搭建Markdown笔记发布系统》这篇文章主要为大家详细介绍了使用Python生态的成熟工具,在30分钟内搭建一个支持Markdown渲染、分类标签、全文搜索的私有化知识发布系统... 目录引言:为什么要自建知识博客一、技术选型:极简主义开发栈二、系统架构设计三、核心代码实现(分步解析

idea maven编译报错Java heap space的解决方法

《ideamaven编译报错Javaheapspace的解决方法》这篇文章主要为大家详细介绍了ideamaven编译报错Javaheapspace的相关解决方法,文中的示例代码讲解详细,感兴趣的... 目录1.增加 Maven 编译的堆内存2. 增加 IntelliJ IDEA 的堆内存3. 优化 Mave

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

论文翻译:arxiv-2024 Benchmark Data Contamination of Large Language Models: A Survey

Benchmark Data Contamination of Large Language Models: A Survey https://arxiv.org/abs/2406.04244 大规模语言模型的基准数据污染:一项综述 文章目录 大规模语言模型的基准数据污染:一项综述摘要1 引言 摘要 大规模语言模型(LLMs),如GPT-4、Claude-3和Gemini的快

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear

数学建模笔记—— 非线性规划

数学建模笔记—— 非线性规划 非线性规划1. 模型原理1.1 非线性规划的标准型1.2 非线性规划求解的Matlab函数 2. 典型例题3. matlab代码求解3.1 例1 一个简单示例3.2 例2 选址问题1. 第一问 线性规划2. 第二问 非线性规划 非线性规划 非线性规划是一种求解目标函数或约束条件中有一个或几个非线性函数的最优化问题的方法。运筹学的一个重要分支。2

【C++学习笔记 20】C++中的智能指针

智能指针的功能 在上一篇笔记提到了在栈和堆上创建变量的区别,使用new关键字创建变量时,需要搭配delete关键字销毁变量。而智能指针的作用就是调用new分配内存时,不必自己去调用delete,甚至不用调用new。 智能指针实际上就是对原始指针的包装。 unique_ptr 最简单的智能指针,是一种作用域指针,意思是当指针超出该作用域时,会自动调用delete。它名为unique的原因是这个