VIT中PatchEmbed、MultiHeadAttention代码详解(PyTorch)

2023-11-01 01:30

本文主要是介绍VIT中PatchEmbed、MultiHeadAttention代码详解(PyTorch),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文对PatchEmbed和MulitHeadAttention进行代码的详细解读,希望可以给同样被此处困扰的小伙伴提供一些帮助,如有错误,还望指正。

文章目录

  • 一、VIT简单介绍
  • 二、PatchEmbed
      • 1.PatchEmbed的目的
      • 2.代码的执行过程
      • 3.注意
      • 4.完整代码解释
      • 5.代码简化版
  • 三、Attention机制
      • 1.self-attention和MultiHeadAttention的区别
      • 2.部分代码解释
      • 3.实现思想
      • 4.完整代码解释

一、VIT简单介绍

相信看到本文的小伙伴基本都是了解了VIT为何物,否则也不会对PatchEmbed感兴趣,所以本文只对VIT做一个简单的介绍。
VIT是Vision Transformer的简称,是将Transformer模型运用在图片上的一个重要的网络模型,也是Transformer四大核心模块之一。
其思想就是将图片分块再拼接形成如同文本数据一般的序列数据,方便将数据输入到Transformer网络中。如图为VIT的网络模型结构,本文不会讨论其所有的子模块,而是选择器PatchEmbed模块和MultiHeadAttention模块进行代码的详解。
在这里插入图片描述

二、PatchEmbed

1.PatchEmbed的目的

将输入的图片用分块再拼接的思想转化为序列的形式,因为Transformer只能接收序列数据。注意这里只是用了分块再拼接的思想,看代码的时候,不需要这个思想也是可以看懂的,如果理解不了,就直接看代码就可以了。或者说看懂了代码之后就理解这个思想了。

2.代码的执行过程

1.输入的图片size为[B, 3, 224, 224]
2.确定好分块的大小为patch_size=16,确定好16,就可以确定卷积核的大小为16,步长为16,即patch_size = kernel_size = stride
3.首先图片通过卷积nn.Conv2d(3, 768,(16,16), (16,16))后size变为[B, 768, 14, 14]
4.再经历一次flatten(2),变为[B, 768, 14*14=196],这里flatten(2)的2意思是在位序为2开始进行展平
5.最后经过一次转置transpose(1, 2),size变成[B, 196, 768]

3.注意

许多人还是不理解为什么要将图片的size转成[B, 196, 768],因为Transformer接受的是序列格式的数据,而不是图片4维【B,C,H,W】的格式,序列如文本数据的格式为【B,N,C】,N为token的个数,C为每个token的维度。只有将图片通过分块拼接成序列形式,才可以输入到transformer网络中。

4.完整代码解释

class PatchEmbed(nn.Module):def __init__(self, img_size=224,  # 输入图片大小patch_size=16,  # 分块大小in_c=3,  # 输入图片的通道数embed_dim=768,  # 经过PatchEmbed后的分块的通道数norm_layer=None): # 标准化层super(PatchEmbed, self).__init__()img_size = (img_size, img_size)  #将img_size、patch_size转为元组patch_size = (patch_size, patch_size)self.img_size = img_sizeself.patch_size = patch_size# // 是一种特殊除号,作用为向下取整# grid_size:分块后的网格大小,即一张图片切分为块后形成的网格结构,理解不了不用理解,就是为了求出分块数目的self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])self.num_patches = self.grid_size[0] * self.grid_size[1]  # 分块数量self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)  # 分块用的卷积# 如果norm_layer为None,就使用一个空占位层,就是看要不要进行一个标准化self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()# nn.Identity()层是用来占位的,没什么用def forward(self, x):B, C, H, W = x.shape# assert是python的断言,当后面跟的是False时就会停下assert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})""""x = self.proj(x).flatten(2).transpose(1, 2)1.第一步将x做卷积 [B, 3, 224, 224] -> [B, 768, 14, 14]2.从位序为2的维度开始将x展平 [B, 768, 14, 14] -> [B, 768, 196]3.转置[B, 196, 768] 得到batch批次,每个批次有196个“词”,每个“词”有768维"""x = self.proj(x).flatten(2).transpose(1, 2)x = self.norm(x)return x

5.代码简化版

上述代码可以简化为如下代码,不同之处在于使用了Rearrange函数
Rearrange函数可以很方便的操作张量的shape,直接替代了view和reshape方法
Rearrange函数的简单使用如下:

from einops import rearrangeimg = torch.randn(1, 3, 224, 224)
print(img.shape)
patch = rearrange(img, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=16, s2=16)
"""
解释:
img [1, 3, 224, 224]
【b c (h s1) (w s2)】其中s1=s2=16,故可知h=w=224/16=14
故【b (h w) (s1 s2 c)】=[b 196 768]
"""
print(patch.shape)

简化版的PatchEmbed如下:

class PatchEmbed(nn.Module):def __init__(self, patch_size=16, in_channel=3, emb_size=768):super(PatchEmbed, self).__init__()self.patch_embed_linear = nn.Sequential(# 将原始图片切分为16*16并将其拉平Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),nn.Linear(patch_size * patch_size * in_channel, emb_size))def forward(self, x):x = self.patch_embed_linear(x)return x

除此之外,使用卷积操作也是可以的:

class PatchEmbedding(nn.Module):def __init__(self, in_channel=3, embed_dim=768, patch_size=16):super(PatchEmbedding, self).__init__()self.patch_embed_conv = nn.Sequential(# [b, 3, 224, 224]nn.Conv2d(in_channel, embed_dim, kernel_size=patch_size, stride=patch_size),# [b, 768, 14, 14]Rearrange('b c h w -> b (h w) c')# [b, 196, 768])def forward(self, x):x = self.patch_embed_conv(x)return x

三、Attention机制

1.self-attention和MultiHeadAttention的区别

自注意力机制和多头注意力机制原理上几乎差不多,而二者的不同之处在于自注意机制是用一组QKV来使token获取上下文信息。
而由下图可知,多头注意力机制是使用多组QKV来让token得到多组的上下文信息,最后使用一个W0矩阵对得到的所有Zi进行整合。
在这里插入图片描述

2.部分代码解释

在下面的完整代码中,有如下一行代码,刚好找到了图解,所以单独拿出来,以便于理解。

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

上面的代码可以用下图来理解,通过一次的全连接操作,就可以生成x的QKV矩阵
在这里插入图片描述
通过上图的解释,不难得出,该行代码可以用如下三行代码来替换:

self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)

对于代码中的参数qk_scale,记住这是公式中的根号dk就可以了。在这里插入图片描述

3.实现思想

代码在实现多头注意力机制的时候,使用了一次计算多组的方法,即多头所用的qkv,一次性生成,各组间的计算也是一次性通过矩阵计算的方式并行计算完成。

# 一次性生成
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# 一次性计算
attn = (q @ k.transpose(-2, -1)) * self.scale  # 计算相似度
x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # 计算注意力值

在这两行代码中,q、k、v代表的就是多组的qkv矩阵,通过一个矩阵计算的算式即可将每一组的qkv都计算出来。

4.完整代码解释

class Attention(nn.Module):# 在实现上多头注意实际上就是在单头的基础上增添num_heads个维度,且在最后输出attention时增加一个权重矩阵def __init__(self,dim,  # 输入token的dim 768num_heads=8,qkv_bias=False,  # 在生成qkv时是否使用偏置qk_scale=None,  # q、k的缩放因子,保证内积计算不会受到向量长度的影响attn_drop_ratio=0,proj_drop_ratio=0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // self.num_heads  # 计算每一个head需要传入的dim  768/8=96self.scale = qk_scale or head_dim ** -0.5  # 若给定qk_scale则使用其作为缩放因子,若没给则使用后者self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)"""self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)目的:得到对应x的q、k、v矩阵,其中x是token_num个dim维的token组成的矩阵过程:x:[token_num, dim] 经过Linear层后得到矩阵 qkv[token_num, dim * 3]将qkv矩阵按dim进行拆分,就可以得到size为[token_num, dim]的q、k、v三个矩阵故该线性层可以拆分为:self.q = nn.Linear(dim, dim)self.k = nn.Linear(dim, dim)self.v = nn.Linear(dim, dim)新的理解:经过一个线性层,就是让输入矩阵乘一个[in_channel, out_channel]的矩阵如:x[token_num, dim] 经过 Linear(dim, dim*3) 就是乘一个[dim, dim*3]的矩阵,最后变成[token_num, dim*3]"""self.attn_drop = nn.Dropout(attn_drop_ratio)self.proj = nn.Linear(dim, dim)  # 将每一个head的结果拼接的时候所乘的权重self.proj_drop = nn.Dropout(proj_drop_ratio)def forward(self, x): # x是经历了PatchEmbed后的xB, N, C = x.shape # 【B,N,C】:【B, 196, 768】 qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)"""输入x:[batch, N, C]1.self.qkv(x) : qkv:[B, N, 3*C]2.reshape() : qkv:[B, N, 3, self.num_heads, C // self.num_heads]3.permute(2, 0, 3, 1, 4) : qkv:[3, B, self.num_heads, N, C // self.num_heads]size说明:3:将qkv分为q、k、v三个矩阵 | q:[B, self.num_heads, N, C(dim)]B: 每个q/k/v矩阵都对应有B个batch | 单个q : [self.num_heads, N, C(dim)]self.num_heads : 在根据头数,将q/k/v划分为对应头数个矩阵 | 每个头:[N, C(dim)]:  反正就是将qkv划分为和输入x一致大小的矩阵"""q, k, v = qkv[0], qkv[1], qkv[2]  # 【B,8,N,96】:【batch,8个头,N个词,每个词96维】attn = (q @ k.transpose(-2, -1)) * self.scale  # 计算相似度"""q、k、v都是【B, 8, N, 96】的矩阵,就是多头注意力机制的多个qkv然后利用attn = (q @ k.transpose(-2, -1)) * self.scale公式让这多组q@k一次性计算出来x = (attn @ v).transpose(1, 2).reshape(B, N, C)也是一样的,通过一个公式将多组的	 softmax(q@k)@v计算出来"""attn = attn.softmax(dim=-1)  # 计算概率attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # 计算注意力值x = self.proj(x)  # 乘最后的权重矩阵x = self.proj_drop(x)return x

这篇关于VIT中PatchEmbed、MultiHeadAttention代码详解(PyTorch)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/319525

相关文章

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

6.1.数据结构-c/c++堆详解下篇(堆排序,TopK问题)

上篇:6.1.数据结构-c/c++模拟实现堆上篇(向下,上调整算法,建堆,增删数据)-CSDN博客 本章重点 1.使用堆来完成堆排序 2.使用堆解决TopK问题 目录 一.堆排序 1.1 思路 1.2 代码 1.3 简单测试 二.TopK问题 2.1 思路(求最小): 2.2 C语言代码(手写堆) 2.3 C++代码(使用优先级队列 priority_queue)

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点赞 👍 收藏 ⭐评论 📝 🍅 文末获取源码联系 👇🏻 精彩专栏推荐订阅 👇🏻 不然下次找不到哟~Java毕业设计项目~热门选题推荐《1000套》 目录 1.技术选型 2.开发工具 3.功能

K8S(Kubernetes)开源的容器编排平台安装步骤详解

K8S(Kubernetes)是一个开源的容器编排平台,用于自动化部署、扩展和管理容器化应用程序。以下是K8S容器编排平台的安装步骤、使用方式及特点的概述: 安装步骤: 安装Docker:K8S需要基于Docker来运行容器化应用程序。首先要在所有节点上安装Docker引擎。 安装Kubernetes Master:在集群中选择一台主机作为Master节点,安装K8S的控制平面组件,如AP

代码随想录冲冲冲 Day39 动态规划Part7

198. 打家劫舍 dp数组的意义是在第i位的时候偷的最大钱数是多少 如果nums的size为0 总价值当然就是0 如果nums的size为1 总价值是nums[0] 遍历顺序就是从小到大遍历 之后是递推公式 对于dp[i]的最大价值来说有两种可能 1.偷第i个 那么最大价值就是dp[i-2]+nums[i] 2.不偷第i个 那么价值就是dp[i-1] 之后取这两个的最大值就是d

pip-tools:打造可重复、可控的 Python 开发环境,解决依赖关系,让代码更稳定

在 Python 开发中,管理依赖关系是一项繁琐且容易出错的任务。手动更新依赖版本、处理冲突、确保一致性等等,都可能让开发者感到头疼。而 pip-tools 为开发者提供了一套稳定可靠的解决方案。 什么是 pip-tools? pip-tools 是一组命令行工具,旨在简化 Python 依赖关系的管理,确保项目环境的稳定性和可重复性。它主要包含两个核心工具:pip-compile 和 pip

D4代码AC集

贪心问题解决的步骤: (局部贪心能导致全局贪心)    1.确定贪心策略    2.验证贪心策略是否正确 排队接水 #include<bits/stdc++.h>using namespace std;int main(){int w,n,a[32000];cin>>w>>n;for(int i=1;i<=n;i++){cin>>a[i];}sort(a+1,a+n+1);int i=1