Open-Sora Code Deep Dive (1): The DiT Architecture
Preface: There are not many open-source DiT-based video generation models at the moment, and among them Open-Sora has the most active developer ecosystem. It covers the classic building blocks of diffusion-based video generation: DiT, spatio-temporal DiT, 3D VAE, Rectified Flow, causal convolution, and more. This post starts from the Open-Sora code and digs into the principles behind it.
Contents
Key improvements of DiT over UNet
Tokenization
Causal 3D convolution
Adaptive Layer Norm (adaLN) block
The complete DiT design
Key improvements of DiT over UNet
Although the Transformer architecture has demonstrated excellent scalability on many natural language processing and computer vision tasks, UNet was still the dominant backbone for diffusion models at the time. Replacing UNet with a DiT architecture mainly requires exploring the following key questions:

- Tokenization. A Transformer consumes a 1D sequence in $\mathbb{R}^{T \times d}$ (ignoring the batch dimension), whereas the LDM latent representation $z \in \mathbb{R}^{\frac{H}{f} \times \frac{W}{f} \times C}$ is a spatial tensor. A suitable tokenization scheme is therefore needed to map the 2D latent into a 1D sequence.
- Condition embedding. A key reason Stable Diffusion became so popular is that it can generate high-quality images from a user's text prompt, and the core of that is injecting text features into the diffusion model during generation. In addition, every denoising step also needs a time embedding to carry the timestep information. Replacing UNet with a Transformer therefore requires a systematic study of how to embed conditions into the Transformer architecture (a minimal sketch of how these two signals are fused follows this list).
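To make the condition-embedding point concrete, here is a minimal, hypothetical sketch of how Open-Sora's DiT (shown in full at the end of this post) fuses the timestep embedding and the pooled text embedding into a single condition vector; the shapes and random tensors are purely illustrative assumptions:

```python
import torch

# Illustrative shapes only: batch of B clips, hidden size D (1152 in Open-Sora's DiT).
B, D = 2, 1152
t_emb = torch.randn(B, D)  # timestep embedding (e.g. produced by a TimestepEmbedder MLP)
y_emb = torch.randn(B, D)  # pooled text embedding (e.g. produced by a CaptionEmbedder)

# DiT simply sums the two signals into one condition vector; each block later
# maps this vector to its adaLN shift/scale/gate parameters.
condition = t_emb + y_emb  # (B, D)
```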
Tokenization
Suppose the original image is $x \in \mathbb{R}^{256 \times 256 \times 3}$; passing it through the auto-encoder yields the latent representation $z \in \mathbb{R}^{32 \times 32 \times 4}$. DiT first converts the latent $z$ into a token sequence using ViT-style patchification, and then adds positional encodings to the sequence (the DiT paper illustrates this patchify step with a figure). The patch size $p$ is a hyperparameter.
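As a quick sanity check of the resulting sequence length (plain arithmetic, not Open-Sora code): for a $32 \times 32$ latent the token count is $(32/p)^2$.

```python
# Token count for a 32x32 latent at the patch sizes explored in the DiT paper.
latent_hw = 32
for p in (2, 4, 8):
    num_tokens = (latent_hw // p) ** 2
    print(f"patch_size={p}: {num_tokens} tokens")
# patch_size=2: 256 tokens
# patch_size=4: 64 tokens
# patch_size=8: 16 tokens
```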
That is the description from the original DiT paper. For video, Open-Sora uses a PatchEmbed3D module to perform tokenization:
```python
class PatchEmbed3D(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (int): Patch token size. Default: (2,4,4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, D, H, W = x.size()
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if D % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))

        x = self.proj(x)  # (B C T H W)
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x
```
It first pads the video's spatial and temporal dimensions so they are divisible by the patch size, then applies a single 3D convolution that compresses both time and space (by the patch size along each axis) while expanding the channel dimension from the 4 latent channels to the embedding dimension (96 by default), and finally flattens time and space together into a single token axis:

```python
x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
```
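A small usage sketch of the shapes involved; the input size here is an illustrative assumption, and it presumes torch, torch.nn as nn, and torch.nn.functional as F are imported as required by the class above:

```python
import torch

# A batch of 2 latent videos: 4 channels, 16 frames, 32x32 spatial resolution.
x = torch.randn(2, 4, 16, 32, 32)  # (B, C, T, H, W)

patch_embed = PatchEmbed3D(patch_size=(2, 4, 4), in_chans=4, embed_dim=96)
tokens = patch_embed(x)

# T is halved (16 -> 8) and H, W are divided by 4 (32 -> 8),
# so N = 8 * 8 * 8 = 512 tokens, each of dimension 96.
print(tokens.shape)  # torch.Size([2, 512, 96])
```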
Causal 3D convolution
The tokenization above uses an ordinary 3D convolution, but some other codebases use a causal 3D convolution instead; causal 3D convolutions are very common in video tasks:
A causal 3D convolution (Causal 3D Convolution) is a special 3D convolution that preserves causality when processing data with a time dimension, such as video: the output at the current time step depends only on the current and earlier time steps, never on future ones. As the kernel slides along the time axis it only touches the current and past frames. This matters in sequence modeling and time-series prediction, where the model's outputs must respect causal ordering.

Compared with an ordinary 3D convolution, a causal 3D convolution adds padding along the time dimension so that the output has the same temporal length as the input. Crucially, this padding is added only at the beginning of the time axis (the "past" side) rather than symmetrically at both ends, which guarantees that predicting the current frame never uses information from later frames.
Below is the implementation from EasyAnimate:
```python
class CausalConv3d(nn.Conv3d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,  # : int | tuple[int, int, int]
        stride=1,       # : int | tuple[int, int, int] = 1
        padding=1,      # : int | tuple[int, int, int]  # TODO: change it to 0.
        dilation=1,     # : int | tuple[int, int, int] = 1
        **kwargs,
    ):
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."

        stride = stride if isinstance(stride, tuple) else (stride,) * 3
        assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."

        dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
        assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

        t_ks, h_ks, w_ks = kernel_size
        _, h_stride, w_stride = stride
        t_dilation, h_dilation, w_dilation = dilation

        t_pad = (t_ks - 1) * t_dilation
        # TODO: align with SD
        if padding is None:
            h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
            w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        elif isinstance(padding, int):
            h_pad = w_pad = padding
        else:
            assert NotImplementedError

        self.temporal_padding = t_pad
        self.temporal_padding_origin = math.ceil(((t_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        self.padding_flag = 0

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=(0, h_pad, w_pad),
            **kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        if self.padding_flag == 0:
            x = F.pad(
                x,
                pad=(0, 0, 0, 0, self.temporal_padding, 0),
                mode="replicate",  # TODO: check if this is necessary
            )
        else:
            x = F.pad(
                x,
                pad=(0, 0, 0, 0, self.temporal_padding_origin, self.temporal_padding_origin),
            )
        return super().forward(x)

    def set_padding_one_frame(self):
        def _set_padding_one_frame(name, module):
            if hasattr(module, 'padding_flag'):
                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 1
            for sub_name, sub_mod in module.named_children():
                _set_padding_one_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_padding_one_frame(name, module)

    def set_padding_more_frame(self):
        def _set_padding_more_frame(name, module):
            if hasattr(module, 'padding_flag'):
                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 2
            for sub_name, sub_mod in module.named_children():
                _set_padding_more_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_padding_more_frame(name, module)
```
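A brief usage sketch with illustrative shapes; it assumes math, torch, torch.nn as nn, and torch.nn.functional as F are imported alongside the class above:

```python
import torch

# An 8-frame clip with 4 channels and 32x32 spatial resolution.
x = torch.randn(1, 4, 8, 32, 32)  # (B, C, T, H, W)

conv = CausalConv3d(in_channels=4, out_channels=8, kernel_size=3, padding=1)
y = conv(x)

# The temporal length is preserved because (t_ks - 1) = 2 frames of padding
# are added only in front of the first frame, so frame t never sees frames > t.
print(y.shape)  # torch.Size([1, 8, 8, 32, 32])
```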
Adaptive Layer Norm (adaLN) block
This is one of the core designs of DiT: the adaptive normalization layer (adaLN), which replaces the layer norm in each Transformer block. In short, instead of learning the scale parameter $\gamma$ and shift parameter $\beta$ of layer norm's affine transform as fixed per-layer parameters, adaLN regresses them from the condition embedding.
The original Layer Norm design (a NumPy reference implementation):
```python
import numpy as np


class LayerNorm:
    def __init__(self, feature_dim, epsilon=1e-6):
        self.epsilon = epsilon
        self.gamma = np.random.rand(feature_dim)  # scale parameters
        self.beta = np.random.rand(feature_dim)   # shift parameters

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        return:
            x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        """
        _mean = np.mean(x, axis=-1, keepdims=True)
        _var = np.var(x, axis=-1, keepdims=True)
        x_layer_norm = self.gamma * (x - _mean) / np.sqrt(_var + self.epsilon) + self.beta
        return x_layer_norm
```
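A quick sanity check of the normalization itself; overriding the random affine parameters is an assumption made only for this check:

```python
np.random.seed(0)
ln = LayerNorm(feature_dim=8)
ln.gamma = np.ones(8)   # temporarily disable the affine parameters
ln.beta = np.zeros(8)

x = np.random.randn(2, 4, 8) * 5.0 + 3.0  # arbitrary scale and offset
out = ln(x)

print(out.mean(axis=-1).round(6))  # ~0 for every token
print(out.std(axis=-1).round(3))   # ~1 for every token
```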
The adaLN design in DiT:
```python
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        enable_flash_attn=False,
        enable_layernorm_kernel=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.enable_flash_attn = enable_flash_attn
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flash_attn=enable_flash_attn,
        )
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp))
        return x
```
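The modulate helper is not included in the snippet above. Based on its call sites, a minimal sketch of what it is expected to do looks like the following; the exact Open-Sora implementation may differ in details such as dtype handling, so treat this as an assumption:

```python
def modulate(norm_func, x, shift, scale):
    # x: (B, N, D) tokens; shift/scale: (B, D) vectors regressed from the condition c.
    x = norm_func(x)  # parameter-free LayerNorm (affine=False)
    x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
    return x
```

Note the adaLN-Zero part: as shown in initialize_weights further below, the final linear layer of adaLN_modulation is initialized to zero, so every gate starts at 0 and each block initially behaves like an identity mapping.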
The complete DiT design
At this point the main pieces of DiT have been built. Next, 28 of these DiT blocks are stacked (depth=28) to form the full DiT model:
```python
@MODELS.register_module()
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        input_size=(16, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        learn_sigma=True,
        condition="text",
        no_temporal_pos_emb=False,
        caption_channels=512,
        model_max_length=77,
        dtype=torch.float32,
        enable_flash_attn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        self.num_heads = num_heads
        self.dtype = dtype
        self.use_text_encoder = not condition.startswith("label")
        if enable_flash_attn:
            assert dtype in [
                torch.float16,
                torch.bfloat16,
            ], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}"
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.mlp_ratio = mlp_ratio
        self.depth = depth
        assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT"

        self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed())
        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())

        self.x_embedder = PatchEmbed3D(patch_size, in_channels, embed_dim=hidden_size)
        if not self.use_text_encoder:
            num_classes = int(condition.split("_")[-1])
            self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        else:
            self.y_embedder = CaptionEmbedder(
                in_channels=caption_channels,
                hidden_size=hidden_size,
                uncond_prob=class_dropout_prob,
                act_layer=approx_gelu,
                token_num=1,  # pooled token
            )
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    enable_flash_attn=enable_flash_attn,
                    enable_layernorm_kernel=enable_layernorm_kernel,
                )
                for _ in range(depth)
            ]
        )
        self.final_layer = FinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)

        self.initialize_weights()
        self.enable_flash_attn = enable_flash_attn
        self.enable_layernorm_kernel = enable_layernorm_kernel

    def get_spatial_pos_embed(self):
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[1] // self.patch_size[1],
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[0] // self.patch_size[0],
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def unpatchify(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def forward(self, x, t, y):
        """
        Forward pass of DiT.
        x: (B, C, T, H, W) tensor of inputs
        t: (B,) tensor of diffusion timesteps
        y: list of text
        """
        # origin inputs should be float32, cast to specified dtype
        x = x.to(self.dtype)
        if self.use_text_encoder:
            y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # (B, N, D)
        x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
        x = x + self.pos_embed_spatial
        if not self.no_temporal_pos_emb:
            x = rearrange(x, "b t s d -> b s t d")
            x = x + self.pos_embed_temporal
            x = rearrange(x, "b s t d -> b (t s) d")
        else:
            x = rearrange(x, "b t s d -> b (t s) d")

        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        y = self.y_embedder(y, self.training)  # (N, D)
        if self.use_text_encoder:
            y = y.squeeze(1).squeeze(1)
        condition = t + y

        # blocks
        for _, block in enumerate(self.blocks):
            c = condition
            x = auto_grad_checkpoint(block, x, c)  # (B, N, D)

        # final process
        x = self.final_layer(x, condition)  # (B, N, num_patches * out_channels)
        x = self.unpatchify(x)  # (B, out_channels, T, H, W)

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                if module.weight.requires_grad_:
                    torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

        # Zero-out text embedding layers:
        if self.use_text_encoder:
            nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
            nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
```
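With the default configuration, the token bookkeeping in the forward pass works out as follows (plain arithmetic, useful when reading the rearrange calls above):

```python
input_size = (16, 32, 32)  # (T, H, W) of the latent video
patch_size = (1, 2, 2)

num_temporal = input_size[0] // patch_size[0]  # 16
num_spatial = (input_size[1] // patch_size[1]) * (input_size[2] // patch_size[2])  # 16 * 16 = 256
num_patches = num_temporal * num_spatial  # 4096 tokens per video

print(num_temporal, num_spatial, num_patches)  # 16 256 4096
```

The spatial positional embedding is broadcast over the 16 temporal positions and added to the 256 spatial tokens of each frame, the temporal positional embedding is broadcast over the 256 spatial positions, and everything is then flattened back into a single sequence of 4096 tokens before entering the stacked DiT blocks.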