Open-Sora代码详细解读(1):解读DiT结构

2024-09-08 02:12

本文主要是介绍Open-Sora代码详细解读(1):解读DiT结构,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Diffusion Models专栏文章汇总:入门与实战

前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。

目录

DiT相比于Unet的关键改进点

Token化方法

因果3D卷积

Adaptive Layer Norm (adaLN) block 

完整DiT Block 设计


DiT相比于Unet的关键改进点

虽然Transformer架构已经在诸多自然语言处理和计算机视觉任务中展现出卓越的scalable能力,但目前主导扩散模型架构的仍是UNet。

采用DiT架构替换UNet主要需要探索以下几个关键问题:

  1. Token化处理。Transformer的输入为一维序列,形式为𝑅𝑇×𝑑RT×d(忽略batch维度),而LDM的latent表征𝑧∈𝑅𝐻𝑓×𝑊𝑓×𝐶z∈RfH​×fW​×C为spatial张量。因此,需要设计合适的Token化方法将二维latent映射为一维序列。
  2. 条件信息嵌入。sable diffusion火出圈的一个关键在于它能够根据用户的文本指令生成高质量的图像。这里面的核心在于需要将文本特征嵌入到扩散模型中协同生成。并且扩散模型的每一个生成还需要融入time-embedding来引入时间步的信息。因此,若要用Transformer架构取代Unet需要系统研究Transformer架构的条件嵌入

Token化方法

假定原始图片𝑥∈𝑅256×256×3,经过auto-encoder后得到latent表征𝑧∈𝑅32×32×4。首先DiT 用ViT中patch化的方式将隐表征𝑧转化为token序列,随后给序列添加位置编码。图中展示了patch化的过程。patch_size p是一个超参数。

刚才是DiT原始论文的描述,在视频里用了一个PatchEmbed3D 执行Token化:

class PatchEmbed3D(nn.Module):"""Video to Patch Embedding.Args:patch_size (int): Patch token size. Default: (2,4,4).in_chans (int): Number of input video channels. Default: 3.embed_dim (int): Number of linear projection output channels. Default: 96.norm_layer (nn.Module, optional): Normalization layer. Default: None"""def __init__(self,patch_size=(2, 4, 4),in_chans=3,embed_dim=96,norm_layer=None,flatten=True,):super().__init__()self.patch_size = patch_sizeself.flatten = flattenself.in_chans = in_chansself.embed_dim = embed_dimself.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)if norm_layer is not None:self.norm = norm_layer(embed_dim)else:self.norm = Nonedef forward(self, x):"""Forward function."""# padding_, _, D, H, W = x.size()if W % self.patch_size[2] != 0:x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))if H % self.patch_size[1] != 0:x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))if D % self.patch_size[0] != 0:x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))x = self.proj(x)  # (B C T H W)if self.norm is not None:D, Wh, Ww = x.size(2), x.size(3), x.size(4)x = x.flatten(2).transpose(1, 2)x = self.norm(x)x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)if self.flatten:x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNCreturn x

先把视频的长宽和时间长都填充成偶数,然后用一个3D卷积,把时间、空间都进一步压缩,Channel从4膨胀到96,然后把时空都压缩到一起,即:

x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC

因果3D卷积

刚才Token化用的是普通的3D卷积,其他有些代码里用了因果3D卷积,因果3D卷积在视频任务里非常常用:

因果3D卷积(Causal 3D Convolution)是一种特殊的3D卷积,它在处理具有时间维度的数据(如视频)时保持因果性。这意味着在生成当前时间点的输出时,它只依赖于当前和之前的时间点,而不依赖于未来的时间点。卷积核在时间维度上滑动,它也只会接触到当前和过去的帧。这在序列建模和时间序列预测等任务中非常重要,因为它们需要保证模型输出的因果关系。

与传统的3D卷积相比,因果3D卷积在时间维度上增加了填充(padding),以确保输出的时间长度与输入相同。这种填充通常是在时间维度的开始处添加,而不是在两端添加,这样可以保证在预测当前帧时不会使用到后续帧的信息。通过在时间轴的正方向上(即未来的方向)添加适当的零填充来实现这一点。

下面是EasyAnimate的实现代码:

class CausalConv3d(nn.Conv3d):def __init__(self,in_channels: int,out_channels: int,kernel_size=3, # : int | tuple[int, int, int], stride=1, # : int | tuple[int, int, int] = 1,padding=1, # : int | tuple[int, int, int],  # TODO: change it to 0.dilation=1, # :  int | tuple[int, int, int] = 1,**kwargs,):kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."stride = stride if isinstance(stride, tuple) else (stride,) * 3assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."t_ks, h_ks, w_ks = kernel_size_, h_stride, w_stride = stridet_dilation, h_dilation, w_dilation = dilationt_pad = (t_ks - 1) * t_dilation# TODO: align with SDif padding is None:h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)elif isinstance(padding, int):h_pad = w_pad = paddingelse:assert NotImplementedErrorself.temporal_padding = t_padself.temporal_padding_origin = math.ceil(((t_ks - 1) * w_dilation + (1 - w_stride)) / 2)self.padding_flag = 0super().__init__(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,dilation=dilation,padding=(0, h_pad, w_pad),**kwargs,)def forward(self, x: torch.Tensor) -> torch.Tensor:# x: (B, C, T, H, W)if self.padding_flag == 0:x = F.pad(x,pad=(0, 0, 0, 0, self.temporal_padding, 0),mode="replicate",     # TODO: check if this is necessary)else:x = F.pad(x,pad=(0, 0, 0, 0, self.temporal_padding_origin, self.temporal_padding_origin),)return super().forward(x)def set_padding_one_frame(self):def _set_padding_one_frame(name, module):if hasattr(module, 'padding_flag'):print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))module.padding_flag = 1for sub_name, sub_mod in module.named_children():_set_padding_one_frame(sub_name, sub_mod)for name, module in self.named_children():_set_padding_one_frame(name, module)def set_padding_more_frame(self):def _set_padding_more_frame(name, module):if hasattr(module, 'padding_flag'):print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))module.padding_flag = 2for sub_name, sub_mod in module.named_children():_set_padding_more_frame(sub_name, sub_mod)for name, module in self.named_children():_set_padding_more_frame(name, module)

Adaptive Layer Norm (adaLN) block 

这是DiT里面最核心的设计之一,adaptive normalization layer(adaLN),将transformer block的layer norm替换为adaLN。简单来说就是,原本的将原本layer norm用于仿射变换的scale parameter 𝛾和shift parameter 𝛽 用condition embedding来替代。

原始的Layer Norm设计:

class LayerNorm:def __init__(self, feature_dim, epsilon=1e-6):self.epsilon = epsilonself.gamma = np.random.rand(feature_dim)  # scale parametersself.beta = np.random.rand(feature_dim)  # shift parametrsdef __call__(self, x: np.ndarray) -> np.ndarray:"""Args:x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)return:x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)"""_mean = np.mean(x, axis=-1, keepdims=True)_std = np.var(x, axis=-1, keepdims=True)x_layer_norm = self.gamma * (x - _mean / (_std + self.epsilon)) + self.betareturn x_layer_norm

DiT中的adaLN设计:

class DiTBlock(nn.Module):"""A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning."""def __init__(self,hidden_size,num_heads,mlp_ratio=4.0,enable_flash_attn=False,enable_layernorm_kernel=False,):super().__init__()self.hidden_size = hidden_sizeself.num_heads = num_headsself.enable_flash_attn = enable_flash_attnmlp_hidden_dim = int(hidden_size * mlp_ratio)self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)self.attn = Attention(hidden_size,num_heads=num_heads,qkv_bias=True,enable_flash_attn=enable_flash_attn,)self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))def forward(self, x, c):shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa))x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp))return x

完整DiT Block 设计

好了,到这里已经是把主要的DiT构建出来了,接下来把DiT结构堆积28层,构成了现在的DiT结构:

@MODELS.register_module()
class DiT(nn.Module):"""Diffusion model with a Transformer backbone."""def __init__(self,input_size=(16, 32, 32),in_channels=4,patch_size=(1, 2, 2),hidden_size=1152,depth=28,num_heads=16,mlp_ratio=4.0,class_dropout_prob=0.1,learn_sigma=True,condition="text",no_temporal_pos_emb=False,caption_channels=512,model_max_length=77,dtype=torch.float32,enable_flash_attn=False,enable_layernorm_kernel=False,enable_sequence_parallelism=False,):super().__init__()self.learn_sigma = learn_sigmaself.in_channels = in_channelsself.out_channels = in_channels * 2 if learn_sigma else in_channelsself.hidden_size = hidden_sizeself.patch_size = patch_sizeself.input_size = input_sizenum_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])self.num_patches = num_patchesself.num_temporal = input_size[0] // patch_size[0]self.num_spatial = num_patches // self.num_temporalself.num_heads = num_headsself.dtype = dtypeself.use_text_encoder = not condition.startswith("label")if enable_flash_attn:assert dtype in [torch.float16,torch.bfloat16,], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}"self.no_temporal_pos_emb = no_temporal_pos_embself.mlp_ratio = mlp_ratioself.depth = depthassert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT"self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed())self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())self.x_embedder = PatchEmbed3D(patch_size, in_channels, embed_dim=hidden_size)if not self.use_text_encoder:num_classes = int(condition.split("_")[-1])self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)else:self.y_embedder = CaptionEmbedder(in_channels=caption_channels,hidden_size=hidden_size,uncond_prob=class_dropout_prob,act_layer=approx_gelu,token_num=1,  # pooled token)self.t_embedder = TimestepEmbedder(hidden_size)self.blocks = nn.ModuleList([DiTBlock(hidden_size,num_heads,mlp_ratio=mlp_ratio,enable_flash_attn=enable_flash_attn,enable_layernorm_kernel=enable_layernorm_kernel,)for _ in range(depth)])self.final_layer = FinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)self.initialize_weights()self.enable_flash_attn = enable_flash_attnself.enable_layernorm_kernel = enable_layernorm_kerneldef get_spatial_pos_embed(self):pos_embed = get_2d_sincos_pos_embed(self.hidden_size,self.input_size[1] // self.patch_size[1],)pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)return pos_embeddef get_temporal_pos_embed(self):pos_embed = get_1d_sincos_pos_embed(self.hidden_size,self.input_size[0] // self.patch_size[0],)pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)return pos_embeddef unpatchify(self, x):c = self.out_channelst, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]pt, ph, pw = self.patch_sizex = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))x = rearrange(x, "n t h w r p q c -> n c t r h p w q")imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))return imgsdef forward(self, x, t, y):"""Forward pass of DiT.x: (B, C, T, H, W) tensor of inputst: (B,) tensor of diffusion timestepsy: list of text"""# origin inputs should be float32, cast to specified dtypex = x.to(self.dtype)if self.use_text_encoder:y = y.to(self.dtype)# embeddingx = self.x_embedder(x)  # (B, N, D)x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)x = x + self.pos_embed_spatialif not self.no_temporal_pos_emb:x = rearrange(x, "b t s d -> b s t d")x = x + self.pos_embed_temporalx = rearrange(x, "b s t d -> b (t s) d")else:x = rearrange(x, "b t s d -> b (t s) d")t = self.t_embedder(t, dtype=x.dtype)  # (N, D)y = self.y_embedder(y, self.training)  # (N, D)if self.use_text_encoder:y = y.squeeze(1).squeeze(1)condition = t + y# blocksfor _, block in enumerate(self.blocks):c = conditionx = auto_grad_checkpoint(block, x, c)  # (B, N, D)# final processx = self.final_layer(x, condition)  # (B, N, num_patches * out_channels)x = self.unpatchify(x)  # (B, out_channels, T, H, W)# cast to float32 for better accuracyx = x.to(torch.float32)return xdef initialize_weights(self):# Initialize transformer layers:def _basic_init(module):if isinstance(module, nn.Linear):if module.weight.requires_grad_:torch.nn.init.xavier_uniform_(module.weight)if module.bias is not None:nn.init.constant_(module.bias, 0)self.apply(_basic_init)# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):w = self.x_embedder.proj.weight.datann.init.xavier_uniform_(w.view([w.shape[0], -1]))nn.init.constant_(self.x_embedder.proj.bias, 0)# Initialize timestep embedding MLP:nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)# Zero-out adaLN modulation layers in DiT blocks:for block in self.blocks:nn.init.constant_(block.adaLN_modulation[-1].weight, 0)nn.init.constant_(block.adaLN_modulation[-1].bias, 0)# Zero-out output layers:nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)nn.init.constant_(self.final_layer.linear.weight, 0)nn.init.constant_(self.final_layer.linear.bias, 0)# Zero-out text embedding layers:if self.use_text_encoder:nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

这篇关于Open-Sora代码详细解读(1):解读DiT结构的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1146818

相关文章

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

usaco 1.3 Mixing Milk (结构体排序 qsort) and hdu 2020(sort)

到了这题学会了结构体排序 于是回去修改了 1.2 milking cows 的算法~ 结构体排序核心: 1.结构体定义 struct Milk{int price;int milks;}milk[5000]; 2.自定义的比较函数,若返回值为正,qsort 函数判定a>b ;为负,a<b;为0,a==b; int milkcmp(const void *va,c

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

MCU7.keil中build产生的hex文件解读

1.hex文件大致解读 闲来无事,查看了MCU6.用keil新建项目的hex文件 用FlexHex打开 给我的第一印象是:经过软件的解释之后,发现这些数据排列地十分整齐 :02000F0080FE71:03000000020003F8:0C000300787FE4F6D8FD75810702000F3D:00000001FF 把解释后的数据当作十六进制来观察 1.每一行数据

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点赞 👍 收藏 ⭐评论 📝 🍅 文末获取源码联系 👇🏻 精彩专栏推荐订阅 👇🏻 不然下次找不到哟~Java毕业设计项目~热门选题推荐《1000套》 目录 1.技术选型 2.开发工具 3.功能

自定义类型:结构体(续)

目录 一. 结构体的内存对齐 1.1 为什么存在内存对齐? 1.2 修改默认对齐数 二. 结构体传参 三. 结构体实现位段 一. 结构体的内存对齐 在前面的文章里我们已经讲过一部分的内存对齐的知识,并举出了两个例子,我们再举出两个例子继续说明: struct S3{double a;int b;char c;};int mian(){printf("%zd\n",s

代码随想录冲冲冲 Day39 动态规划Part7

198. 打家劫舍 dp数组的意义是在第i位的时候偷的最大钱数是多少 如果nums的size为0 总价值当然就是0 如果nums的size为1 总价值是nums[0] 遍历顺序就是从小到大遍历 之后是递推公式 对于dp[i]的最大价值来说有两种可能 1.偷第i个 那么最大价值就是dp[i-2]+nums[i] 2.不偷第i个 那么价值就是dp[i-1] 之后取这两个的最大值就是d

pip-tools:打造可重复、可控的 Python 开发环境,解决依赖关系,让代码更稳定

在 Python 开发中,管理依赖关系是一项繁琐且容易出错的任务。手动更新依赖版本、处理冲突、确保一致性等等,都可能让开发者感到头疼。而 pip-tools 为开发者提供了一套稳定可靠的解决方案。 什么是 pip-tools? pip-tools 是一组命令行工具,旨在简化 Python 依赖关系的管理,确保项目环境的稳定性和可重复性。它主要包含两个核心工具:pip-compile 和 pip

D4代码AC集

贪心问题解决的步骤: (局部贪心能导致全局贪心)    1.确定贪心策略    2.验证贪心策略是否正确 排队接水 #include<bits/stdc++.h>using namespace std;int main(){int w,n,a[32000];cin>>w>>n;for(int i=1;i<=n;i++){cin>>a[i];}sort(a+1,a+n+1);int i=1