深入浅出 diffusion(2):pytorch 实现 diffusion 加噪过程

2024-01-26 15:20

本文主要是介绍深入浅出 diffusion(2):pytorch 实现 diffusion 加噪过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

         我在上篇博客深入浅出 diffusion(1):白话 diffusion 原理(无公式)中介绍了 diffusion 的一些基本原理,其中谈到了 diffusion 的加噪过程,本文用pytorch 实现下到底是怎么加噪的。

import torch
import math
import numpy as np
from PIL import Image
import requests
import matplotlib.pyplot as plot
import cv2def linear_beta_schedule(timesteps):"""linear schedule, proposed in original ddpm paper"""scale = 1000 / timestepsbeta_start = scale * 0.0001beta_end = scale * 0.02return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)def cosine_beta_schedule(timesteps, s = 0.008):"""cosine scheduleas proposed in https://openreview.net/forum?id=-NEXDKk8gZ"""steps = timesteps + 1t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timestepsalphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2alphas_cumprod = alphas_cumprod / alphas_cumprod[0]betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])return torch.clip(betas, 0, 0.999)# 时间步(timestep)定义为1000
timesteps = 1000# 定义Beta Schedule, 选择线性版本,同DDPM原文一致,当然也可以换成cosine_beta_schedule
betas = linear_beta_schedule(timesteps=timesteps)# 根据beta定义alpha 
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)# 计算前向过程 diffusion q(x_t | x_{t-1}) 中所需的
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)def extract(a, t, x_shape):batch_size = t.shape[0]out = a.gather(-1, t.cpu())return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)# 前向加噪过程: forward diffusion process
def q_sample(x_start, t, noise=None):if noise is None:noise = torch.randn_like(x_start)cv2.imwrite('noise.png', noise.numpy()*255)sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)print('sqrt_alphas_cumprod_t :', sqrt_alphas_cumprod_t)print('sqrt_one_minus_alphas_cumprod_t :', sqrt_one_minus_alphas_cumprod_t)return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise# 图像后处理
def get_noisy_image(x_start, t):# add noisex_noisy = q_sample(x_start, t=t)# turn back into PIL imagenoisy_image = x_noisy.squeeze().numpy()return noisy_image...# 展示图像, t=0, 50, 100, 500的效果
x_start = cv2.imread('img.png') / 255.0
x_start = torch.tensor(x_start, dtype=torch.float)
cv2.imwrite('img_0.png', get_noisy_image(x_start, torch.tensor([0])) * 255.0)
cv2.imwrite('img_50.png', get_noisy_image(x_start, torch.tensor([50])) * 255.0)
cv2.imwrite('img_100.png', get_noisy_image(x_start, torch.tensor([100])) * 255.0)
cv2.imwrite('img_500.png', get_noisy_image(x_start, torch.tensor([500])) * 255.0)
cv2.imwrite('img_999.png', get_noisy_image(x_start, torch.tensor([999])) * 255.0)sqrt_alphas_cumprod_t : tensor([[[0.9999]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.0100]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.9849]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.1733]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.9461]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.3238]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.2789]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.9603]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.0064]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[1.0000]]], dtype=torch.float64)

        以下分别为原图,t = 0, 50, 100, 500, 999 的结果。

        可见,随着 t 的加大,原图对应的比例系数减小,噪声的强度系数加大,t = 500的时候,隐约可见人脸轮廓,t = 999 的时候,人脸彻底淹没在噪声里面了。

这篇关于深入浅出 diffusion(2):pytorch 实现 diffusion 加噪过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/647227

相关文章

springboot filter实现请求响应全链路拦截

《springbootfilter实现请求响应全链路拦截》这篇文章主要为大家详细介绍了SpringBoot如何结合Filter同时拦截请求和响应,从而实现​​日志采集自动化,感兴趣的小伙伴可以跟随小... 目录一、为什么你需要这个过滤器?​​​二、核心实现:一个Filter搞定双向数据流​​​​三、完整代码

SpringBoot利用@Validated注解优雅实现参数校验

《SpringBoot利用@Validated注解优雅实现参数校验》在开发Web应用时,用户输入的合法性校验是保障系统稳定性的基础,​SpringBoot的@Validated注解提供了一种更优雅的解... 目录​一、为什么需要参数校验二、Validated 的核心用法​1. 基础校验2. php分组校验3

Python实现AVIF图片与其他图片格式间的批量转换

《Python实现AVIF图片与其他图片格式间的批量转换》这篇文章主要为大家详细介绍了如何使用Pillow库实现AVIF与其他格式的相互转换,即将AVIF转换为常见的格式,比如JPG或PNG,需要的小... 目录环境配置1.将单个 AVIF 图片转换为 JPG 和 PNG2.批量转换目录下所有 AVIF 图

Pydantic中model_validator的实现

《Pydantic中model_validator的实现》本文主要介绍了Pydantic中model_validator的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录引言基础知识创建 Pydantic 模型使用 model_validator 装饰器高级用法mo

AJAX请求上传下载进度监控实现方式

《AJAX请求上传下载进度监控实现方式》在日常Web开发中,AJAX(AsynchronousJavaScriptandXML)被广泛用于异步请求数据,而无需刷新整个页面,:本文主要介绍AJAX请... 目录1. 前言2. 基于XMLHttpRequest的进度监控2.1 基础版文件上传监控2.2 增强版多

Redis分片集群的实现

《Redis分片集群的实现》Redis分片集群是一种将Redis数据库分散到多个节点上的方式,以提供更高的性能和可伸缩性,本文主要介绍了Redis分片集群的实现,具有一定的参考价值,感兴趣的可以了解一... 目录1. Redis Cluster的核心概念哈希槽(Hash Slots)主从复制与故障转移2.

springboot+dubbo实现时间轮算法

《springboot+dubbo实现时间轮算法》时间轮是一种高效利用线程资源进行批量化调度的算法,本文主要介绍了springboot+dubbo实现时间轮算法,文中通过示例代码介绍的非常详细,对大家... 目录前言一、参数说明二、具体实现1、HashedwheelTimer2、createWheel3、n

使用Python实现一键隐藏屏幕并锁定输入

《使用Python实现一键隐藏屏幕并锁定输入》本文主要介绍了使用Python编写一个一键隐藏屏幕并锁定输入的黑科技程序,能够在指定热键触发后立即遮挡屏幕,并禁止一切键盘鼠标输入,这样就再也不用担心自己... 目录1. 概述2. 功能亮点3.代码实现4.使用方法5. 展示效果6. 代码优化与拓展7. 总结1.

Mybatis 传参与排序模糊查询功能实现

《Mybatis传参与排序模糊查询功能实现》:本文主要介绍Mybatis传参与排序模糊查询功能实现,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友跟随小编一起看看吧... 目录一、#{ }和${ }传参的区别二、排序三、like查询四、数据库连接池五、mysql 开发企业规范一、#{ }和${ }传参的

Docker镜像修改hosts及dockerfile修改hosts文件的实现方式

《Docker镜像修改hosts及dockerfile修改hosts文件的实现方式》:本文主要介绍Docker镜像修改hosts及dockerfile修改hosts文件的实现方式,具有很好的参考价... 目录docker镜像修改hosts及dockerfile修改hosts文件准备 dockerfile 文