【扩散模型】5、Improved DDPM | 引入可学习方差和余弦加噪机制来提升 DDPM

本文主要是介绍【扩散模型】5、Improved DDPM | 引入可学习方差和余弦加噪机制来提升 DDPM,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里插入图片描述

文章目录

    • 一、背景
    • 二、Improved DDPM——提升 Log-likelihood
      • 2.1 可学习的方差
      • 2.2 改进 noise schedule
      • 2.3 降低梯度噪声
    • 三、效果

论文:Improved Denoising Diffusion Probabilistic Models

代码:https://link.zhihu.com/?target=https%3A//github.com/openai/improved-diffusion

时间:2021.02.18

Improved DDPM 贡献:

  • 学习方差会让生成效果更好(DDPM 中只学习了均值,方差是一个常数)
  • 提出了余弦加噪方法,比线性加噪效果更好

一、背景

首先回顾一下 DDPM

前向传播过程:

  • 通过给输入 x 0 x_0 x0 进行 t t t 次加噪 β t ∈ ( 0 , 1 ) \beta_t \in (0,1) βt(0,1),得到最终的 x t x_t xt

    在这里插入图片描述

  • 假设给定一个足够大的 T T T 和一个变化规则良好的 β t \beta_t βt,则 x T x_T xT 就近似一个各向同性高斯分布。

  • 假设已知 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt),就是能直接从 x t x_t xt 推出 x t − 1 x_{t-1} xt1,那么就能一路反推得到 q ( x 0 ) q(x_0) q(x0),从而采样出 x 0 x_0 x0,但是没有办法直接推出来,所以只能使用神经网络来估计出来每次反推的结果:

    在这里插入图片描述

  • 将 q 和 p 结合起来就是一个变分自编码器,可将变分下界(variational lower bound, VLB)写成如下形式:

    在这里插入图片描述

  • 公式 4 中,除了 L0 以外,其他每项都是两个高斯分布的 KL 散度

  • x 0 x_0 x0 可以直接得到 x t x_t xt,且边界分布如下,噪声的系数是方差,可以用这个系数来描述噪声的 schedule

    在这里插入图片描述
    在这里插入图片描述

  • 基于贝叶斯理论,可以计算后验分布如下:

    在这里插入图片描述

实际训练过程:

  • 目标函数 4 是多个独立项之和,每一项 L t − 1 L_{t-1} Lt1 基本都是真实噪声和预测噪声的 KL 散度

  • 怎么预测噪声均值 μ θ \mu_{\theta} μθ 呢,之前的方法大都是直接使用神经网络来预测,还有一种方法是通过预测 x 0 x_0 x0,然后基于公式 11 来预测。此外,还能通过使用公式 9 和 11 来得到:

    在这里插入图片描述

  • DDPM 中发现预测噪声能做的比较好,尤其是使用 reweighted loss 函数,下面的函数 14 可以看做从公式 4 中重加权得到的,且发现直接优化下面的公式 14 比优化 4 更好:

    在这里插入图片描述

二、Improved DDPM——提升 Log-likelihood

尽管 DDPM 在 FID 和 Inception Score 上获得很很好的效果,但在 Log-likelihood 上没有得到很高的得分

Log-likelihood 也是生成式任务上一个很重要的衡量指标,一般认为优化 Log-likelihood 能够让生成式模型捕捉数据分布的整体信息,所以,探索 DDPM 为什么在 Log-likelihood 上表现的不好还是很重要的

其理论出处文中给的是 VQ-VAE2:

在这里插入图片描述

2.1 可学习的方差

DDPM 在优化 L s a m p l e L_{sample} Lsample 的时候,设置的固定的方差 σ t 2 I \sigma_t^2I σt2I,方差是没有学习的,当 σ t 2 = β t \sigma_t^2=\beta_t σt2=βt σ t 2 = β ˜ t \sigma_t^2=\~{\beta}_t σt2=β˜t 时,采样质量没什么差别。

所以 DDPM 设置的 σ t 2 = β t \sigma_t^2=\beta_t σt2=βt ,T=1000 的情况下,在 ImageNet 64x64 上训练 200k iter 时, log-likelihood = 3.99。

本文作者尝试将 T=4000 时,log-likelihood 提升到了 3.77。

将固定方差变成可学习的方差:

  • 在 DDPM 中, ∑ θ ( x t , t ) = σ t 2 I \sum_{\theta}(x_t,t)=\sigma_t^2I θ(xt,t)=σt2I,其中 σ t \sigma_t σt 是不可学习的,是固定成了 σ t = β t \sigma_t=\beta_t σt=βt,且和 σ t 2 = β ˜ t \sigma_t^2=\~{\beta}_t σt2=β˜t 时的采样效果没什么大的差别

  • 一般来说, β t \beta_t βt β ˜ t \~{\beta}_t β˜t 表示了两种相反的极端,但为什么这种选择不会影响采样结果呢。如图 1 所示,展示了两者相除的结果,可以看出 β t \beta_t βt β ˜ t \~{\beta}_t β˜t 除了在 t=0 附近不太相同以外,在后面的部分相除的结果都接近于 1,且随着 T 的增大,这两者更加接近。这就说明在无限增大扩散步骤时, σ t \sigma_t σt 的选择对采样质量影响不大。也就是在使用更多的扩散步骤时,模型的平均值 μ θ ( x t , t ) \mu_{\theta}(x_t, t) μθ(xt,t) 比方差 ∑ θ ( x t , t ) \sum_{\theta}(x_t,t) θ(xt,t) 更能决定这个分布。

    在这里插入图片描述

  • Improved DDPM 想如何改进:本文作者认为,虽然 DDPM 中证明了固定的 σ t \sigma_t σt 基本上不会影响采样的效果,但没说不会影响 log-likelihood 啊!所以,Improved DDPM 作者觉得可能会影响 log-likelihood,于是就在图 2 中展示了扩散模型的前几个 step 对变分下界的影响,而且发现了前几个 step 对变分下届的贡献最大,所以,似乎可以通过选择更好的 ∑ θ ( x t , t ) \sum_{\theta}(x_t,t) θ(xt,t) 来提高 log-likelihood,所以,Improved DDPM 选择了学习 ∑ θ ( x t , t ) \sum_{\theta}(x_t,t) θ(xt,t),而非固定的模式。

    在这里插入图片描述

如何学习 ∑ θ ( x t , t ) \sum_{\theta}(x_t,t) θ(xt,t)

  • 如图 1 所示, ∑ θ ( x t , t ) \sum_{\theta}(x_t,t) θ(xt,t) 的变化范围很小,所以很难直接使用神经网络来预测这个值

  • 本文作者发现将其参数化为在 β t \beta_t βt β ˜ t \~{\beta}_t β˜t 在 log domain 之间的插值,也就是说模型输出一个向量 v v v,每个维度包含一个元素,使用如下的方式将输出变成方差:
    在这里插入图片描述

  • 而且没有对 v v v 进行额外的约束,但其也不会越界。所以最终的目标函数如下,且 λ = 0.001 \lambda=0.001 λ=0.001

    在这里插入图片描述

2.2 改进 noise schedule

在 DDPM 中使用的是线性加噪的方式,在高分辨率的图上表现的较好,但对 64x64 和 32x32 的图来说,并非最优的。

前向加噪过程是随机的,且对后面的采样过程也不很重要。加噪过程如图 3 所示。

在这里插入图片描述

影响如图 4 所示,当跳过 20% 的反向过程时,使用线性加噪规则训练的模型(橘色)也不会变得更糟(使用 FID 衡量)。

在这里插入图片描述

因此,本文作者提出了余弦加噪方式:

在这里插入图片描述

在这里插入图片描述

  • 这里使用的偏移 s 很小,是为了在 t=0 附近让 β t \beta_t βt 更小
  • 因为作者发现,在开始的时候噪声小的话,无法让网络很准确的预测 ϵ \epsilon ϵ,所以 s=0.008.
  • 作者使用 c o s 2 cos^2 cos2 的原因是它是一个常见的期望形状的数学函数,选择也是任意的。

余弦加噪的特点:

  • 在中间过程优一个线性的下降
  • 在 t=0 和 t=T 附近,变化很小

线性加噪的特点:

  • 下降到 0 的速度更快,所以破坏信息的速度更快

在这里插入图片描述

2.3 降低梯度噪声

本文是为了通过直接优化 L v l b L_{vlb} Lvlb 来得到最好的 log-likelihood,而不是优化 L h y b r i d L_{hybrid} Lhybrid

然而,作者发现 L v l b L_{vlb} Lvlb 实际上很难直接优化,至少在变化多样的 ImageNet 64x64 上很难优化。

如图 6 展示了 L v l b L_{vlb} Lvlb L h y b r i d L_{hybrid} Lhybrid 的学习曲线,两个曲线都很 noisy,就是不稳定,波动很大,但是橘色的 L h y b r i d L_{hybrid} Lhybrid 在同样训练步数的情况下的效果是更好一些的。

在这里插入图片描述

作者假设 L v l b L_{vlb} Lvlb 的梯度比 L h y b r i d L_{hybrid} Lhybrid 更 noisy,且通过衡量其梯度的 noisy scales 确定了这一点,如图 7 所示,所以,作者找到了一种降低 L v l b L_{vlb} Lvlb 方差的方法来直接优化 log-likelihood

在这里插入图片描述

如图 2 所示, L v l b L_{vlb} Lvlb 的不太项有不同的模值,所以假设采样 t 会在 L v l b L_{vlb} Lvlb 目标函数中带来均匀的噪声,所以作者使用了 importance sampling :

在这里插入图片描述

  • 由于 E [ L t 2 ] E[L_t^2] E[Lt2] 是事先不知道的,也会在训练的时候改变,所以会保留前 10 次的值,且在训练的时候动态更新。

有了这个 importance sampling 方法,就能够通过优化 L v l b L_{vlb} Lvlb 来实现最佳的 log-likelihood。如图 6,而且 importance sampling 的噪声比原始均匀采样的目标函数小得多。

三、效果

在这里插入图片描述

在这里插入图片描述

这篇关于【扩散模型】5、Improved DDPM | 引入可学习方差和余弦加噪机制来提升 DDPM的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/633450

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss