Transformer实战-系列教程11:SwinTransformer 源码解读4(WindowAttention类)

本文主要是介绍Transformer实战-系列教程11:SwinTransformer 源码解读4(WindowAttention类),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

6、WindowAttention类

6.1 构造函数

class WindowAttention(nn.Module):def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))coords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten = torch.flatten(coords, 1)relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]relative_coords = relative_coords.permute(1, 2, 0).contiguous()relative_coords[:, :, 0] += self.window_size[0] - 1relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)self.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)
  1. dim:输入特征维度
  2. window_size:窗口大小
  3. num_heads:多头注意力头数
  4. head_dim:每头注意力的头数
  5. scale :缩放因子
  6. relative_position_bias_table:相对位置偏置表,它对每个头存储不同窗口位置之间的偏置,以模拟位置信息
  7. coords_h 、coords_w、coords:窗口内每个位置的坐标
  8. coords_flatten :将坐标展平,为计算相对位置做准备
  9. 第1个relative_coords:计算窗口内每个位置相对于其他位置的坐标差
  10. 第2个relative_coords:重排坐标差的维度以符合预期的格式
  11. relative_coords[:, :, 0]、relative_coords[:, :, 1]、relative_coords[:, :, 0]:调整坐标差,使其能够映射到相对位置偏置表中的索引
  12. relative_position_index :计算每对位置之间的相对位置索引
  13. register_buffer:将相对位置索引注册为模型的缓冲区,这样它就不会在训练过程中被更新
  14. qkv :创建一个线性层,用于生成QKV
  15. attn_drop、proj、proj_drop:初始化注意力dropout、输出投影层及其dropout
  16. trunc_normal_:使用截断正态分布初始化相对位置偏置表
  17. softmax :初始化softmax层,用于计算注意力权重

6.2 前向传播

在Swin Transformer中,一共有4个stage,每次stage都分别包含W-MSA和SW-MSA,W-MSA和SW-MSA都是在执行WindowAttention模块,也就是说在一次batch的执行过程中WindowAttention的前向传播过程会执行8次。
每次stage执行后,都会进行下采样操作,这个下采样操作,长宽会减半特征图会增多。

    def forward(self, x, mask=None):B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()attn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return x
  1. B_, N, C = x.shape原始输入: torch.Size([256, 49, 96]),B_, N, C即原始输入的维度
  2. qkv = self.qkv(x).reshape...qkv: torch.Size([3, 256, 3, 49, 32]),被重塑的一个五维张量,分别代表qkv三个维度、256个窗口、3个注意力头数但是不会一直是3越往后会越多、49是一个窗口有7*7=49元素、每个头的特征维度。在之前的Transformer以及Vision Transformer中,都是用x接上各自的全连接后分别生成QKV,这这里直接一起生成了。
  3. q: torch.Size([256, 3, 49, 32]),k: torch.Size([256, 3, 49, 32]),v: torch.Size([256, 3, 49, 32]),从qkv中分解出q、k、v,而且已经包含了多头注意力机制
  4. attn: torch.Size([256, 3, 49, 49]),attn是q和k的点积
  5. relative_position_bias: torch.Size([49, 49, 3]),从相对位置偏置表中索引出每对位置之间的偏置,并重塑以匹配注意力分数的形状
  6. relative_position_bias: torch.Size([3, 49, 49]),重新排列,位置编码在Transformer中一直当成偏置加进去的,而这个位置编码是对一个窗口的,所以每一个窗口的都对应了相同的位置编码
  7. attn: torch.Size([256, 3, 49, 49]),将位置编码加到注意力分数上,到这里就算完了全部的注意力机制了
  8. attn: torch.Size([256, 3, 49, 49]),掩码加到注意力分数上,使用softmax函数归一化注意力分数,得到注意力权重,应用注意力dropout。
  9. x: torch.Size([256, 49, 96]),使用注意力权重对v向量进行重构,然后对结果进行转置和重塑
  10. x: torch.Size([256, 49, 96]),将加权的注意力输出通过一个线性投影层,应用输出dropout,这就是最后WindowAttention的输出,一共256个窗口,每个窗口有49个特征,每个特征对应96维的向量

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

这篇关于Transformer实战-系列教程11:SwinTransformer 源码解读4(WindowAttention类)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/691186

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

Makefile简明使用教程

文章目录 规则makefile文件的基本语法:加在命令前的特殊符号:.PHONY伪目标: Makefilev1 直观写法v2 加上中间过程v3 伪目标v4 变量 make 选项-f-n-C Make 是一种流行的构建工具,常用于将源代码转换成可执行文件或者其他形式的输出文件(如库文件、文档等)。Make 可以自动化地执行编译、链接等一系列操作。 规则 makefile文件

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟 开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚 第一站:海量资源,应有尽有 走进“智听

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

MCU7.keil中build产生的hex文件解读

1.hex文件大致解读 闲来无事,查看了MCU6.用keil新建项目的hex文件 用FlexHex打开 给我的第一印象是:经过软件的解释之后,发现这些数据排列地十分整齐 :02000F0080FE71:03000000020003F8:0C000300787FE4F6D8FD75810702000F3D:00000001FF 把解释后的数据当作十六进制来观察 1.每一行数据

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL