Transformer实战-系列教程10:SwinTransformer 源码解读3

2024-02-07 17:44

本文主要是介绍Transformer实战-系列教程10:SwinTransformer 源码解读3,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

5、SwinTransformerBlock类

class SwinTransformerBlock(nn.Module):def extra_repr(self) -> str:return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

5.1 构造函数

SwinTransformerBlock 是 Swin Transformer 模型中的一个基本构建块。它结合了自注意力机制和多层感知机(MLP),并通过窗口划分和可选的窗口位移来实现局部注意力

def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeself.mlp_ratio = mlp_ratioif min(self.input_resolution) <= self.window_size:self.shift_size = 0self.window_size = min(self.input_resolution)assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)if self.shift_size > 0:H, W = self.input_resolutionimg_mask = torch.zeros((1, H, W, 1))  # 1 H W 1h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))else:attn_mask = Noneself.register_buffer("attn_mask", attn_mask)
  1. dim:输入特征的通道数。
  2. input_resolution:输入特征的分辨率(高度和宽度)
  3. num_heads:自注意力头的数量
  4. window_size:窗口大小,决定了注意力机制的局部范围
  5. shift_size:窗口位移的大小,用于实现错位窗口多头自注意力(SW-MSA)
  6. mlp_ratio:MLP隐层大小与输入通道数的比率
  7. qkv_bias:QKV的偏置
  8. qk_scale:QKV的缩放因子
  9. drop:丢弃率
  10. drop_path:分别控制QKV的偏差、缩放因子、丢弃率、注意力丢弃率和随机深度率
  11. norm_layer:激活层和标准化层,默认分别为 GELU 和 LayerNorm
  12. WindowAttention:窗口注意力模块
  13. Mlp:一个包含全连接层、激活函数、Dropout的模块
  14. img_mask :图像掩码,用于生成错位窗口自注意力
  15. h_slicesw_slices:水平和垂直方向上的切片,用于划分图像掩码
  16. cnt :计数器,标记不同的窗口
  17. mask_windows :图像掩码划分为窗口,并将每个窗口的掩码重塑为一维向量
  18. window_partition
  19. attn_mask :注意力掩码,用于在自注意力计算中排除窗口外的位置
  20. register_buffer:注意力掩码注册为一个模型的缓冲区

5.2 前向传播

def forward(self, x):H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)if self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = xx_windows = window_partition(shifted_x, self.window_size)x_windows = x_windows.view(-1, self.window_size * self.window_size, C)attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, Cattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' Cif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xx = x.view(B, H * W, C)x = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return x
  1. 原始输入: torch.Size([4, 3136, 96]),输入的是一个长度为3136的序列,每个向量的维度为96,在
    被多次调用的时候,维度也发生了变化原始输入: torch.Size([4, 784, 192])、torch.Size([4, 196, 384])、torch.Size([4, 49, 768])
  2. H,W=[ 56,56],输入分辨率中的高度和宽度
  3. B, L, C=[ 4,3136,96],当前输入的维度,批次大小、序列长度和向量的维度
  4. norm1(x): torch.Size([4, 3136, 96]),经过一个层归一化,维度不变
  5. x.view(B, H, W, C): torch.Size([4, 56, 56, 96]),将序列重塑为(Batch_size,Height,Width,Channel)的形状
  6. shifted_x: torch.Size([4, 56, 56, 96]),位移操作后的x
  7. x_windows: torch.Size([256, 7, 7, 96]),将位移后的特征图划分为窗口
  8. x_windows: torch.Size([256, 49, 96]),将窗口重塑为一维向量,以便进行自注意力计算
  9. attn_windows: torch.Size([256, 7, 7, 96]),对每个窗口应用窗口注意力机制,考虑到可能的注意力掩码
  10. shifted_x: torch.Size([4, 56, 56, 96]),注意力操作后的窗口重塑回原始形状,并将它们合并回完整的特征图
  11. torch.Size([4, 56, 56, 96]),如果进行了循环位移,则执行逆向循环位移操作,以恢复原始特征图的位置
  12. torch.Size([4, 3136, 96]),特征图重塑回原始的[B, L, C]形状
  13. torch.Size([4, 3136, 96]),应用残差连接,并通过随机深度(如果设置了的话)
  14. torch.Size([4, 3136, 96]),应用第二个标准化层,然后是MLP,并再次应用随机深度,完成残差连接的最后一步。

这篇关于Transformer实战-系列教程10:SwinTransformer 源码解读3的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/688455

相关文章

Python办公自动化实战之打造智能邮件发送工具

《Python办公自动化实战之打造智能邮件发送工具》在数字化办公场景中,邮件自动化是提升工作效率的关键技能,本文将演示如何使用Python的smtplib和email库构建一个支持图文混排,多附件,多... 目录前言一、基础配置:搭建邮件发送框架1.1 邮箱服务准备1.2 核心库导入1.3 基础发送函数二、

PowerShell中15个提升运维效率关键命令实战指南

《PowerShell中15个提升运维效率关键命令实战指南》作为网络安全专业人员的必备技能,PowerShell在系统管理、日志分析、威胁检测和自动化响应方面展现出强大能力,下面我们就来看看15个提升... 目录一、PowerShell在网络安全中的战略价值二、网络安全关键场景命令实战1. 系统安全基线核查

使用Docker构建Python Flask程序的详细教程

《使用Docker构建PythonFlask程序的详细教程》在当今的软件开发领域,容器化技术正变得越来越流行,而Docker无疑是其中的佼佼者,本文我们就来聊聊如何使用Docker构建一个简单的Py... 目录引言一、准备工作二、创建 Flask 应用程序三、创建 dockerfile四、构建 Docker

解读GC日志中的各项指标用法

《解读GC日志中的各项指标用法》:本文主要介绍GC日志中的各项指标用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、基础 GC 日志格式(以 G1 为例)1. Minor GC 日志2. Full GC 日志二、关键指标解析1. GC 类型与触发原因2. 堆

Java设计模式---迭代器模式(Iterator)解读

《Java设计模式---迭代器模式(Iterator)解读》:本文主要介绍Java设计模式---迭代器模式(Iterator),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,... 目录1、迭代器(Iterator)1.1、结构1.2、常用方法1.3、本质1、解耦集合与遍历逻辑2、统一

从原理到实战深入理解Java 断言assert

《从原理到实战深入理解Java断言assert》本文深入解析Java断言机制,涵盖语法、工作原理、启用方式及与异常的区别,推荐用于开发阶段的条件检查与状态验证,并强调生产环境应使用参数验证工具类替代... 目录深入理解 Java 断言(assert):从原理到实战引言:为什么需要断言?一、断言基础1.1 语

Java MQTT实战应用

《JavaMQTT实战应用》本文详解MQTT协议,涵盖其发布/订阅机制、低功耗高效特性、三种服务质量等级(QoS0/1/2),以及客户端、代理、主题的核心概念,最后提供Linux部署教程、Sprin... 目录一、MQTT协议二、MQTT优点三、三种服务质量等级四、客户端、代理、主题1. 客户端(Clien

MySQL之InnoDB存储页的独立表空间解读

《MySQL之InnoDB存储页的独立表空间解读》:本文主要介绍MySQL之InnoDB存储页的独立表空间,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、独立表空间【1】表空间大小【2】区【3】组【4】段【5】区的类型【6】XDES Entry区结构【

在Spring Boot中集成RabbitMQ的实战记录

《在SpringBoot中集成RabbitMQ的实战记录》本文介绍SpringBoot集成RabbitMQ的步骤,涵盖配置连接、消息发送与接收,并对比两种定义Exchange与队列的方式:手动声明(... 目录前言准备工作1. 安装 RabbitMQ2. 消息发送者(Producer)配置1. 创建 Spr

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现