结合代码详细讲解DDPM的训练和采样过程

2024-08-30 23:12

本文主要是介绍结合代码详细讲解DDPM的训练和采样过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本篇文章结合代码讲解Denoising Diffusion Probabilistic Models(DDPM),首先我们先不关注推导过程,而是结合代码来看一下训练和推理过程是如何实现的,推导过程会在别的文章中讲解;首先我们来看一下论文中的算法描述。DDPM分为扩散过程和反向扩散过程,也就是训练过程和采样过程;
代码来自https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-

请添加图片描述

1. 训练(扩散)过程

首先我们来逐个看一下训练过程中的所有符号的含义:

x 0 x_0 x0是真实图像;

t 是扩散的步数,取值范围从1到T;

ϵ \epsilon ϵ是从标准正态分布中采样的噪声;

ϵ θ \epsilon_\theta ϵθ是模型,用于预测噪声,其输入是 x t x_t xt和 t;

x t x_t xt的表达式如下:

在这里插入图片描述

x t x_t xt x 0 x_0 x0加噪获得,其中 α t ‾ \overline{\alpha_{t}} αt是常数
因此训练过程总结成一句话就是,向真实图像 x 0 x_0 x0中加噪,获得加噪后的图像 x t x_t xt;然后将 x t x_t xt和t输入到网络中,得到预测的噪声,通过使得网络预测的噪声和真实加入的噪声更接近,完成网络的训练。
从另一个角度,我们也可以这么理解:向 x 0 x_0 x0中加噪的过程,可以理解成是编码的过程,加噪之后获取到了图像的中间表示 x t x_t xt;而预测噪声的过程则是从 x t x_t xt解码的过程,只是并没有选择直接解码出 x 0 x_0 x0,而是解码出加入的噪声,也就是残差。请添加图片描述

下面来看一下代码,跟上面讲解的过程是一一对应的,首先在初始化函数中我们需要准备好每个时刻t所需要的常数量 α t ‾ \sqrt{\overline{\alpha_{t}}} αt 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1αt 。这些参数最原始来源于一个超参数 β t \beta_t βt,这个参数为加入噪声的方差。他们的关系如下:

[图片]

所以很容易理解代码中的sqrt_alphas_bar就是 α t ‾ \sqrt{\overline{\alpha_{t}}} αt ,sqrt_one_minus_alphas_bar 就是 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1αt
接着在forward函数中,首先从[0,T]中随机选取一个时刻t,然后从标准正态分布中采样一个噪声,shape和 x 0 x_0 x0一致,接着获取 x t x_t xt

x_t = (
extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)

然后将然后将 x t x_t xt和t输入到网络中,得到预测的噪声:

self.model(x_t, t)

计算Loss函数:

loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')

训练过程的完整代码:

class GaussianDiffusionTrainer(nn.Module):def __init__(self, model, beta_1, beta_T, T):super().__init__()self.model = modelself.T = Tself.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())alphas = 1. - self.betasalphas_bar = torch.cumprod(alphas, dim=0)# calculations for diffusion q(x_t | x_{t-1}) and othersself.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))# 每次forward时,给每个样本随机取一个t,并采样一个高斯噪声,然后根据t从sqrt_alphas_bar和sqrt_one_minus_alphas_bar中取出对应的系数,然后根据x_0和采样的高斯噪声生成x_t。然后将x_t和t输入到噪声预测网络中,得到预测的噪声。预测出的噪声输入到网络中,计算loss,从而实现model的训练。def forward(self, x_0):"""Algorithm 1."""t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) # 给batch中每个样本取一个t,取值范围是[0, 1000]noise = torch.randn_like(x_0) # 采样高斯噪声,shape与x_0一致x_t = (extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')return loss

2. 推理(反向)过程

首先我们来明确一下,反向过程的目标是什么。反向过程的目标是逐步从一张噪声图像 x T x_T xT中恢复出一张图像,表示成 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt),我们没法推导出 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt),但是 p ( x t − 1 ∣ x t , x 0 ) p(x_{t-1}|x_t, x_0) p(xt1xt,x0)是可以用贝叶斯公式推导出来的,其也是一个高斯分布,并且可以把 x 0 x_0 x0化简掉。最终 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)分布的均值为:
请添加图片描述

方差为 β t \beta_t βt
因此我们可以从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)分布中采样出一个 x t − 1 x_{t-1} xt1
请添加图片描述
这种采样方式叫做重参数技巧,如果不了解可以看如下介绍:
在这里插入图片描述
注意:是标准差与标准正态分布相乘,而不是方差;

因为DDPM的方差固定为 β t \beta_t βt,所以反向过程的重点就是学习出这个分布的方差,从上面的表达式可以看出分布的均值与 x t x_t xt和当前时刻加入的噪声 ϵ t \epsilon_t ϵt有关,而我们的模型可以完成对 ϵ t \epsilon_t ϵt的预测,只要将 x t x_t xt和 t 输入进去模型中即可。代码中描述的过程与此一一对应。

注意代码中存在三个噪声,其中eps是模型预测出来的,其和分布的均值计算相关;forward函数中的noise也是噪声,但是它是从标准正态分布中采样的,用于从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)采样;forward函数中的 x T x_T xT是整个反向过程的输入,也是从标准正态分布中采样的。

# 反向过程是从纯噪声x_T开始逐步去噪以生成样本,此过程也是一个高斯分布,均值和x_t以及预测出的噪声相关,方差在ddpm中没有进行学习,直接使用的是后验分布q(x_t-1|x_t,x_0)的方差。
class GaussianDiffusionSampler(nn.Module):def __init__(self, model, beta_1, beta_T, T):super().__init__()self.model = modelself.T = Tself.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())alphas = 1. - self.betasalphas_bar = torch.cumprod(alphas, dim=0)alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]self.register_buffer('coeff1', torch.sqrt(1. / alphas))self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))def predict_xt_prev_mean_from_eps(self, x_t, t, eps):assert x_t.shape == eps.shapereturn (extract(self.coeff1, t, x_t.shape) * x_t -extract(self.coeff2, t, x_t.shape) * eps)def p_mean_variance(self, x_t, t):# below: only log_variance is used in the KL computationsvar = torch.cat([self.posterior_var[1:2], self.betas[1:]])var = extract(var, t, x_t.shape)eps = self.model(x_t, t)xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)return xt_prev_mean, vardef forward(self, x_T):"""Algorithm 2."""x_t = x_T # 输入是一个标准正态分布噪声# 从T到1进行reverse过程for time_step in reversed(range(self.T)):print(time_step)t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_stepmean, var= self.p_mean_variance(x_t=x_t, t=t) # no noise when t == 0if time_step > 0:noise = torch.randn_like(x_t)else:noise = 0x_t = mean + torch.sqrt(var) * noise # 从q(x_t-1|x_t)中采样assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."x_0 = x_treturn torch.clip(x_0, -1, 1)

这篇关于结合代码详细讲解DDPM的训练和采样过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1122208

相关文章

利用Python调试串口的示例代码

《利用Python调试串口的示例代码》在嵌入式开发、物联网设备调试过程中,串口通信是最基础的调试手段本文将带你用Python+ttkbootstrap打造一款高颜值、多功能的串口调试助手,需要的可以了... 目录概述:为什么需要专业的串口调试工具项目架构设计1.1 技术栈选型1.2 关键类说明1.3 线程模

Python Transformers库(NLP处理库)案例代码讲解

《PythonTransformers库(NLP处理库)案例代码讲解》本文介绍transformers库的全面讲解,包含基础知识、高级用法、案例代码及学习路径,内容经过组织,适合不同阶段的学习者,对... 目录一、基础知识1. Transformers 库简介2. 安装与环境配置3. 快速上手示例二、核心模

如何为Yarn配置国内源的详细教程

《如何为Yarn配置国内源的详细教程》在使用Yarn进行项目开发时,由于网络原因,直接使用官方源可能会导致下载速度慢或连接失败,配置国内源可以显著提高包的下载速度和稳定性,本文将详细介绍如何为Yarn... 目录一、查询当前使用的镜像源二、设置国内源1. 设置为淘宝镜像源2. 设置为其他国内源三、还原为官方

最详细安装 PostgreSQL方法及常见问题解决

《最详细安装PostgreSQL方法及常见问题解决》:本文主要介绍最详细安装PostgreSQL方法及常见问题解决,介绍了在Windows系统上安装PostgreSQL及Linux系统上安装Po... 目录一、在 Windows 系统上安装 PostgreSQL1. 下载 PostgreSQL 安装包2.

Java的栈与队列实现代码解析

《Java的栈与队列实现代码解析》栈是常见的线性数据结构,栈的特点是以先进后出的形式,后进先出,先进后出,分为栈底和栈顶,栈应用于内存的分配,表达式求值,存储临时的数据和方法的调用等,本文给大家介绍J... 目录栈的概念(Stack)栈的实现代码队列(Queue)模拟实现队列(双链表实现)循环队列(循环数组

MySql match against工具详细用法

《MySqlmatchagainst工具详细用法》在MySQL中,MATCH……AGAINST是全文索引(Full-Textindex)的查询语法,它允许你对文本进行高效的全文搜素,支持自然语言搜... 目录一、全文索引的基本概念二、创建全文索引三、自然语言搜索四、布尔搜索五、相关性排序六、全文索引的限制七

python中各种常见文件的读写操作与类型转换详细指南

《python中各种常见文件的读写操作与类型转换详细指南》这篇文章主要为大家详细介绍了python中各种常见文件(txt,xls,csv,sql,二进制文件)的读写操作与类型转换,感兴趣的小伙伴可以跟... 目录1.文件txt读写标准用法1.1写入文件1.2读取文件2. 二进制文件读取3. 大文件读取3.1

Python结合PyWebView库打造跨平台桌面应用

《Python结合PyWebView库打造跨平台桌面应用》随着Web技术的发展,将HTML/CSS/JavaScript与Python结合构建桌面应用成为可能,本文将系统讲解如何使用PyWebView... 目录一、技术原理与优势分析1.1 架构原理1.2 核心优势二、开发环境搭建2.1 安装依赖2.2 验

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效

使用Java将DOCX文档解析为Markdown文档的代码实现

《使用Java将DOCX文档解析为Markdown文档的代码实现》在现代文档处理中,Markdown(MD)因其简洁的语法和良好的可读性,逐渐成为开发者、技术写作者和内容创作者的首选格式,然而,许多文... 目录引言1. 工具和库介绍2. 安装依赖库3. 使用Apache POI解析DOCX文档4. 将解析