本文主要是介绍Transformer实战-系列教程8:SwinTransformer 源码解读1(项目配置/SwinTransformer类),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
🚩🚩🚩Transformer实战-系列教程总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
https://download.csdn.net/download/weixin_50592077/88809977?spm=1001.2014.3001.5501
SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)
1、项目配置
本项目来自SwinTransformer 的GitHub官方源码:
Image Classification: Included in this repo. See get_started.md for a quick start.
Object Detection and Instance Segmentation: See Swin Transformer for Object Detection.
Semantic Segmentation: See Swin Transformer for Semantic Segmentation.
Video Action Recognition: See Video Swin Transformer.
Semi-Supervised Object Detection: See Soft Teacher.
SSL: Contrasitive Learning: See Transformer-SSL.
SSL: Masked Image Modeling: See get_started.md#simmim-support.
Mixture-of-Experts: See get_started for more instructions.
Feature-Distillation: See Feature-Distillation.
此处包含多个版本(分类、检测、分割、视频 ),但是仅仅学习算法建议选择第一个图像分类的基础版本就可以了
安装需求:
pip install timm==0.4.12
pip install yacs==0.1.8
pip install termcolor==1.1.0
pytorch
opencv
Apex(linux版本)
原本的数据是imagenet,这个数据太多了,有很多开源的imagenet小版本,本文配套的资源就是已经配好的imagenet小版本,目录信息、数据标注、数据划分都已经做好了
本项目的执行文件就是main.py,源码我已经修改了部分
配置参数:
--cfg configs/swin_tiny_patch4_window7_224.yaml
--data-path imagenet
--local_rank 0
--batch-size 4
–local rank 0这个参数表示的是分布式训练,直接用当前的这个就好
2、SwinTransformer类
打开models有两个构建模型的源码:
build.py
swin_transformer.py
构建模型的部分主要就在swin_transformer.py,一共有600多行代码
首先看SwinTransformer类的前向传播函数:
class SwinTransformer(nn.Module):def forward(self, x):x = self.forward_features(x)x = self.head(x)return x
打印这个过程的shape值:
- 原始输入x: torch.Size([4, 3, 224, 224]),原始输入是一张彩色图像,4是batch,3是通道数,图像是224*244的长宽
- self.forward_features(x):torch.Size([4, 768]),经过forward_features函数后,变成了768维的向量
- self.head(x):torch.Size([4, 1000]),head是一个全连接层,很显然这个1000是最后的分类数
所以整个体征提取的过程都在self.forward_features()函数中:
def forward_features(self, x):x = self.patch_embed(x)if self.ape:x = x + self.absolute_pos_embedx = self.pos_drop(x)for layer in self.layers:x = layer(x)x = self.norm(x) # B L Cx = self.avgpool(x.transpose(1, 2)) # B C 1x = torch.flatten(x, 1)return x
-
原始输入x: torch.Size([4, 3, 224, 224]),原始输入是一张彩色图像
-
patch_embed: torch.Size([4, 3136, 96]),图像经过patch_embbeding变成一个Transformer需要的序列,相当于序列是3136个向量,每个向量维度是96。这个过程通常包括将图像分割成多个patches,然后将每个patch线性投影到一个指定的维度。
-
if self.ape: x = x + self.absolute_pos_embed
,如果模型配置了绝对位置编码(self.ape
为真),这行代码会将绝对位置嵌入加到patch的嵌入上。绝对位置嵌入提供了每个patch在图像中位置的信息,帮助模型理解图像中不同部分的空间关系, 不改变维度 -
pos_drop: torch.Size([4, 3136, 96]),一层Dropout
-
layer: torch.Size([4, 784, 192]),for循环主要是Swin Transformer Block的堆叠
-
layer: torch.Size([4, 196, 384]),4次循环,序列长度减小
-
layer: torch.Size([4, 49, 768]),特征图个数增多,即向量维度变大
-
layer: torch.Size([4, 49, 768]),最后一次维度不变
-
norm: torch.Size([4, 49, 768]),层归一化,维度不变
-
avgpool: torch.Size([4, 768, 1]),平均池化
-
flatten: torch.Size([4, 768]),拉平操作,去掉多余的维度
SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)
这篇关于Transformer实战-系列教程8:SwinTransformer 源码解读1(项目配置/SwinTransformer类)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!