结合代码详细讲解DDPM的训练和采样过程

2024-08-30 23:12

本文主要是介绍结合代码详细讲解DDPM的训练和采样过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本篇文章结合代码讲解Denoising Diffusion Probabilistic Models(DDPM),首先我们先不关注推导过程,而是结合代码来看一下训练和推理过程是如何实现的,推导过程会在别的文章中讲解;首先我们来看一下论文中的算法描述。DDPM分为扩散过程和反向扩散过程,也就是训练过程和采样过程;
代码来自https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-

请添加图片描述

1. 训练(扩散)过程

首先我们来逐个看一下训练过程中的所有符号的含义:

x 0 x_0 x0是真实图像;

t 是扩散的步数,取值范围从1到T;

ϵ \epsilon ϵ是从标准正态分布中采样的噪声;

ϵ θ \epsilon_\theta ϵθ是模型,用于预测噪声,其输入是 x t x_t xt和 t;

x t x_t xt的表达式如下:

在这里插入图片描述

x t x_t xt x 0 x_0 x0加噪获得,其中 α t ‾ \overline{\alpha_{t}} αt是常数
因此训练过程总结成一句话就是,向真实图像 x 0 x_0 x0中加噪,获得加噪后的图像 x t x_t xt;然后将 x t x_t xt和t输入到网络中,得到预测的噪声,通过使得网络预测的噪声和真实加入的噪声更接近,完成网络的训练。
从另一个角度,我们也可以这么理解:向 x 0 x_0 x0中加噪的过程,可以理解成是编码的过程,加噪之后获取到了图像的中间表示 x t x_t xt;而预测噪声的过程则是从 x t x_t xt解码的过程,只是并没有选择直接解码出 x 0 x_0 x0,而是解码出加入的噪声,也就是残差。请添加图片描述

下面来看一下代码,跟上面讲解的过程是一一对应的,首先在初始化函数中我们需要准备好每个时刻t所需要的常数量 α t ‾ \sqrt{\overline{\alpha_{t}}} αt 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1αt 。这些参数最原始来源于一个超参数 β t \beta_t βt,这个参数为加入噪声的方差。他们的关系如下:

[图片]

所以很容易理解代码中的sqrt_alphas_bar就是 α t ‾ \sqrt{\overline{\alpha_{t}}} αt ,sqrt_one_minus_alphas_bar 就是 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1αt
接着在forward函数中,首先从[0,T]中随机选取一个时刻t,然后从标准正态分布中采样一个噪声,shape和 x 0 x_0 x0一致,接着获取 x t x_t xt

x_t = (
extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)

然后将然后将 x t x_t xt和t输入到网络中,得到预测的噪声:

self.model(x_t, t)

计算Loss函数:

loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')

训练过程的完整代码:

class GaussianDiffusionTrainer(nn.Module):def __init__(self, model, beta_1, beta_T, T):super().__init__()self.model = modelself.T = Tself.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())alphas = 1. - self.betasalphas_bar = torch.cumprod(alphas, dim=0)# calculations for diffusion q(x_t | x_{t-1}) and othersself.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))# 每次forward时,给每个样本随机取一个t,并采样一个高斯噪声,然后根据t从sqrt_alphas_bar和sqrt_one_minus_alphas_bar中取出对应的系数,然后根据x_0和采样的高斯噪声生成x_t。然后将x_t和t输入到噪声预测网络中,得到预测的噪声。预测出的噪声输入到网络中,计算loss,从而实现model的训练。def forward(self, x_0):"""Algorithm 1."""t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) # 给batch中每个样本取一个t,取值范围是[0, 1000]noise = torch.randn_like(x_0) # 采样高斯噪声,shape与x_0一致x_t = (extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')return loss

2. 推理(反向)过程

首先我们来明确一下,反向过程的目标是什么。反向过程的目标是逐步从一张噪声图像 x T x_T xT中恢复出一张图像,表示成 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt),我们没法推导出 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt),但是 p ( x t − 1 ∣ x t , x 0 ) p(x_{t-1}|x_t, x_0) p(xt1xt,x0)是可以用贝叶斯公式推导出来的,其也是一个高斯分布,并且可以把 x 0 x_0 x0化简掉。最终 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)分布的均值为:
请添加图片描述

方差为 β t \beta_t βt
因此我们可以从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)分布中采样出一个 x t − 1 x_{t-1} xt1
请添加图片描述
这种采样方式叫做重参数技巧,如果不了解可以看如下介绍:
在这里插入图片描述
注意:是标准差与标准正态分布相乘,而不是方差;

因为DDPM的方差固定为 β t \beta_t βt,所以反向过程的重点就是学习出这个分布的方差,从上面的表达式可以看出分布的均值与 x t x_t xt和当前时刻加入的噪声 ϵ t \epsilon_t ϵt有关,而我们的模型可以完成对 ϵ t \epsilon_t ϵt的预测,只要将 x t x_t xt和 t 输入进去模型中即可。代码中描述的过程与此一一对应。

注意代码中存在三个噪声,其中eps是模型预测出来的,其和分布的均值计算相关;forward函数中的noise也是噪声,但是它是从标准正态分布中采样的,用于从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)采样;forward函数中的 x T x_T xT是整个反向过程的输入,也是从标准正态分布中采样的。

# 反向过程是从纯噪声x_T开始逐步去噪以生成样本,此过程也是一个高斯分布,均值和x_t以及预测出的噪声相关,方差在ddpm中没有进行学习,直接使用的是后验分布q(x_t-1|x_t,x_0)的方差。
class GaussianDiffusionSampler(nn.Module):def __init__(self, model, beta_1, beta_T, T):super().__init__()self.model = modelself.T = Tself.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())alphas = 1. - self.betasalphas_bar = torch.cumprod(alphas, dim=0)alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]self.register_buffer('coeff1', torch.sqrt(1. / alphas))self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))def predict_xt_prev_mean_from_eps(self, x_t, t, eps):assert x_t.shape == eps.shapereturn (extract(self.coeff1, t, x_t.shape) * x_t -extract(self.coeff2, t, x_t.shape) * eps)def p_mean_variance(self, x_t, t):# below: only log_variance is used in the KL computationsvar = torch.cat([self.posterior_var[1:2], self.betas[1:]])var = extract(var, t, x_t.shape)eps = self.model(x_t, t)xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)return xt_prev_mean, vardef forward(self, x_T):"""Algorithm 2."""x_t = x_T # 输入是一个标准正态分布噪声# 从T到1进行reverse过程for time_step in reversed(range(self.T)):print(time_step)t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_stepmean, var= self.p_mean_variance(x_t=x_t, t=t) # no noise when t == 0if time_step > 0:noise = torch.randn_like(x_t)else:noise = 0x_t = mean + torch.sqrt(var) * noise # 从q(x_t-1|x_t)中采样assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."x_0 = x_treturn torch.clip(x_0, -1, 1)

这篇关于结合代码详细讲解DDPM的训练和采样过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1122208

相关文章

浅析Spring Security认证过程

类图 为了方便理解Spring Security认证流程,特意画了如下的类图,包含相关的核心认证类 概述 核心验证器 AuthenticationManager 该对象提供了认证方法的入口,接收一个Authentiaton对象作为参数; public interface AuthenticationManager {Authentication authenticate(Authenti

作业提交过程之HDFSMapReduce

作业提交全过程详解 (1)作业提交 第1步:Client调用job.waitForCompletion方法,向整个集群提交MapReduce作业。 第2步:Client向RM申请一个作业id。 第3步:RM给Client返回该job资源的提交路径和作业id。 第4步:Client提交jar包、切片信息和配置文件到指定的资源提交路径。 第5步:Client提交完资源后,向RM申请运行MrAp

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点赞 👍 收藏 ⭐评论 📝 🍅 文末获取源码联系 👇🏻 精彩专栏推荐订阅 👇🏻 不然下次找不到哟~Java毕业设计项目~热门选题推荐《1000套》 目录 1.技术选型 2.开发工具 3.功能

代码随想录冲冲冲 Day39 动态规划Part7

198. 打家劫舍 dp数组的意义是在第i位的时候偷的最大钱数是多少 如果nums的size为0 总价值当然就是0 如果nums的size为1 总价值是nums[0] 遍历顺序就是从小到大遍历 之后是递推公式 对于dp[i]的最大价值来说有两种可能 1.偷第i个 那么最大价值就是dp[i-2]+nums[i] 2.不偷第i个 那么价值就是dp[i-1] 之后取这两个的最大值就是d

pip-tools:打造可重复、可控的 Python 开发环境,解决依赖关系,让代码更稳定

在 Python 开发中,管理依赖关系是一项繁琐且容易出错的任务。手动更新依赖版本、处理冲突、确保一致性等等,都可能让开发者感到头疼。而 pip-tools 为开发者提供了一套稳定可靠的解决方案。 什么是 pip-tools? pip-tools 是一组命令行工具,旨在简化 Python 依赖关系的管理,确保项目环境的稳定性和可重复性。它主要包含两个核心工具:pip-compile 和 pip

D4代码AC集

贪心问题解决的步骤: (局部贪心能导致全局贪心)    1.确定贪心策略    2.验证贪心策略是否正确 排队接水 #include<bits/stdc++.h>using namespace std;int main(){int w,n,a[32000];cin>>w>>n;for(int i=1;i<=n;i++){cin>>a[i];}sort(a+1,a+n+1);int i=1

沁恒CH32在MounRiver Studio上环境配置以及使用详细教程

目录 1.  RISC-V简介 2.  CPU架构现状 3.  MounRiver Studio软件下载 4.  MounRiver Studio软件安装 5.  MounRiver Studio软件介绍 6.  创建工程 7.  编译代码 1.  RISC-V简介         RISC就是精简指令集计算机(Reduced Instruction SetCom