Vit Transformer

2024-03-27 08:52
文章标签 transformer vit

本文主要是介绍Vit Transformer,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一 VitTransformer 介绍

vit : An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

        论文是基于Attention Is All You Need,由于图像数据和词数据数据格式不一样,经典的transformer不能处理图像数据,在视觉领域的应用有限。本文提出的方法可以将transformer直接应用图像分类任务,引入Patch Embedding,位置编码等方法,克服了Transformer在处理图像数据时的限制。整体流程如下。

从图中可以看出, Vision Transformer 主要有三个部分组成: 1 ) 第一部分是Linear Projection of Flattened Patches ,也就是 Emdedding 层,主要的工作就是将图像数据转换成transformer可以处理的数据格式。2)第二部分是Transformer Encoder部分,它是vit 最核心的组件(原始的NLP的transformer还有Decoder部分)。它主要是层归一化,多头注意力机制,MLP,Dropout/DropPath四个小block组成,用于学习图像数据。3) 第三部分就是MLP head ,用于分类。

二 PatchEmbedding & Positional Encoding

        首先,每个图像被分割成一系列不重叠的块(16x16或者 32x32),然后做一个线性的embedding ,由于这些块如果并行的输入到transformer中,不提供位置信息,模型不知道这些块的顺序。因此要加一个 positional encoding。 

        在实际的实现上,图像数据是[batch_size, C , H, W] 的格式,要将其变成[batch_size , token_len , dim],其中token_len 可以理解成图像patch token的数量。以[4,3,224,224]的图像为例子,首先我们模拟分割块,对于一个图像,我们要将其分割成 (H*w)/(patch_size*patch_size)个patches,即(224x224)// (16x16) = 196个 patches 。每个patch的大小是(3,16,16),然后我们将其flatten一个768( 3x16x16)dim的 token。这样数据格式就变成[4,196,768]。

代码分割图像块 : 

def split_patches(x, patch_size=16):batch_size, channels, height, width = x.shapex = x.reshape(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)x = x.permute(0, 2, 4, 1, 3, 5)x = x.reshape(batch_size, -1, channels * patch_size * patch_size)return x

当然这个过程可以通过卷积实现,官方代码其实就是用卷积来实现的。

class PatchEmbed(nn.Module) :def __init__(self,img_size=224,patch_size=16,in_channels=3,embed_dim=768,norm_layer=None):super().__init__()img_size = (img_size,img_size)patch_size = (patch_size,patch_size)self.img_size = img_sizeself.patch_size = patch_sizeself.grid_size = (img_size[0] // patch_size[0],img_size[1]//patch_size[1])self.num_patches = self.grid_size[0] * self.grid_size[1] # 卷积层self.proj = nn.Conv2d(in_channels=in_channels,out_channels=embed_dim,kernel_size=patch_size,stride=patch_size)# 归一化self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()def forward(self,x) :B,C,H,W = x.shapeassert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}x{W} does not match input image size ({self.img_size[0]}x{self.img_size[1]}"x = self.proj(x).flatten(2).transpose(1,2)x = self.norm(x)return x

 positional Embedding 

由于输入的图像数据patch序列没有能够表达patch之间相对位置关系,因此需要加入位置编码(Positional encoding)这个特征,为了得到不同位置的对应的编码,Transformer模型使用不同频率的正余弦函数

PE(Pos,2i) = sin(\frac{pos}{10000^{2i/d}})

PE(Pos,2i+1) = cos(\frac{pos}{10000^{2i/d}})

 其中 pos是表示token(flattened image patch)的位置,2i和2i+1表示位置编码向量中对应的维度,d是对应位置编码的总维度。

def add_positional_encoding(x, max_len):batch_size, patch_numbers, dim = x.shapeposition = torch.arange(max_len).reshape(-1, 1)div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))pe = torch.zeros((max_len, dim))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)x += pe[:patch_numbers]return x

三 Self-Attention 

        其中最核心的部分就是对于注意力部分了。在基于Transformer的机器翻译模型中,要建模源语言和目标语言任意两个单词的依赖关系,引入自注意力K(键)Q(查询)V(值)。这三个用来计算上下文单词所对应的权重得分,这些权重反映了在编码当前单词时,对于上下文不同部分所需要关注程度。

    同样在vision transformer中,对于一个个image patch token 来说,也需要建模任意 token之间的相互关注关系,当处理当前token时,哪些token与它有更高的关联度。

        上图是论文中的Scaled Dot-Product AttentionMulti-head Attention,我们首先定义三个矩阵 Q,K,V,这三个矩阵是由 输入X([4,196,768])分别经过三个权重矩阵$w_{q},w_{k},w_{v}$得到的。其中 Q 矩阵和K 矩阵,V矩阵是“同源”的,因为它们都是来自于同一个输入序列(图像patch token)的某种表示(线性变换的嵌入表示)。

根据Attenton分数的计算公式,Q(shape=[4,196,768])左乘一个K(shape=[4,196,768])矩阵的转置,得到一个相似度矩阵(shape=[4,196,196]),为了防止过大的相似度数值在后续Softmax计算过程中导致的梯度爆炸以及收敛效率差的问题,因此使用一个缩放因子\sqrt{d}缩放来稳定优化。放缩后的得分经过Softmax函数归一化为概率后,与其他位置的值向量相乘来聚合希望关注的上下文信息,并最小化不相关信息的干扰。

def self_attention(x, w_q, w_k, w_v):query = torch.matmul(x, w_q)key = torch.matmul(x, w_k)value = torch.matmul(x, w_v)scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))attention_scores = softmax(scores)output = torch.matmul(attention_scores, value)return attention_scores, output

过程如下

四 Multi-head self-attention (MSA)

        为了进一步提升自注意力机制的全局信息聚合能力,提出了Multi-head attention机制,具体来说,上下文的每个token 向量的表示x_{i}的经过多组的线性{W_{j}^{Q},W_{j}^{K},W_{j}^{V}}映射到不同的表示空间。计算出不同子空间得到的attention score得到{Z_{j}}_{j=1}^{N},再用一个线性变换w^{o} 用于综合不同子空间中的上下文表示形成最后的输出。

import torch
import torch.nn.functional as Fclass MultiHeadAttention(torch.nn.Module):def __init__(self, input_dim, num_heads, head_dim):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = head_dimassert input_dim % self.num_heads == 0self.projection_dim = input_dim // self.num_heads# 定义 权重矩阵 self.weight_q = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))self.weight_k = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))self.weight_v = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))self.weight_combine = torch.nn.Parameter(torch.randn(num_heads * self.projection_dim, input_dim))def forward(self, x):batch_size, seq_length, _ = x.size()queries = torch.matmul(x, self.weight_q)keys = torch.matmul(x, self.weight_k)values = torch.matmul(x, self.weight_v)queries = queries.view(batch_size, seq_length, self.num_heads, self.projection_dim)keys = keys.view(batch_size, seq_length, self.num_heads, self.projection_dim)values = values.view(batch_size, seq_length, self.num_heads, self.projection_dim)queries = queries.transpose(1, 2)keys = keys.transpose(1, 2)values = values.transpose(1, 2)# 计算得分scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.projection_dim ** 0.5)attention_weights = F.softmax(scores, dim=-1)attention_output = torch.matmul(attention_weights, values)attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1)output = torch.matmul(attention_output, self.weight_combine)return outputinput_dim = 64
num_heads = 8
head_dim = input_dim // num_heads
seq_length = 10
batch_size = 4multihead_attention = MultiHeadAttention(input_dim, num_heads, head_dim)
x = torch.rand(batch_size, seq_length, input_dim)
output = multihead_attention(x)print("输入形状:", x.shape)
print("输出形状:", output.shape)


五 代码实现

import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import transforms
import numpy as npdef softmax(x):return torch.nn.functional.softmax(x, dim=-1)def split_patches(x, patch_size=16):batch_size, channels, height, width = x.shapex = x.reshape(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)x = x.permute(0, 2, 4, 1, 3, 5)x = x.reshape(batch_size, -1, channels * patch_size * patch_size)return xdef add_positional_encoding(x, max_len):batch_size, patch_numbers, dim = x.shapeposition = torch.arange(max_len).reshape(-1, 1)div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))pe = torch.zeros((max_len, dim))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)x += pe[:patch_numbers]return xdef plot_heatmap(scores,index,name ):plt.figure(figsize=(8, 6))plt.imshow(scores, cmap='hot', interpolation='nearest')plt.xlabel('Keys')plt.ylabel('Queries')plt.title(f'Attention Scores Heatmap {name}')plt.colorbar()plt.savefig(f"./attention_heatmap{index}.png")# plt.show()
def self_attention(x, w_q, w_k, w_v):query = torch.matmul(x, w_q)key = torch.matmul(x, w_k)value = torch.matmul(x, w_v)scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))attention_scores = softmax(scores)output = torch.matmul(attention_scores, value)return attention_scores, outputdef plot_heatmap_on_image(image, attention_scores, patch_size=16,index=0):# 对每个patch的注意力分数求平均attention_scores_mean = attention_scores.mean(dim=1)# 将注意力分数转换为与原始图像大小相匹配的热力图attention_map = attention_scores_mean.view(1, 1, int(224 / patch_size), int(224 / patch_size))attention_map = torch.nn.functional.interpolate(attention_map, size=(224, 224), mode='bilinear', align_corners=False)attention_map = attention_map.squeeze().cpu().detach().numpy()plt.figure(figsize=(6, 6))plt.imshow(image)plt.imshow(attention_map, cmap='jet', alpha=0.5)plt.axis('off')plt.savefig(f'attention_map{index}.png')# plt.show()if __name__ == '__main__':batch_size = 4channels = 3height = 224width = 224input_dim = 768output_dim = 64transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()])image_paths = ["./images/11.jpg","./images/15.jpg","./images/16.jpg","./images/17.jpg"]images = torch.zeros((4, 3, 224, 224), dtype=torch.float32)for i, path in enumerate(image_paths):img = Image.open(path).convert('RGB')img_tensor = transform(img)images[i] = img_tensorpatch_embeddings = split_patches(images, patch_size=16)patch_embeddings_pe = add_positional_encoding(patch_embeddings, max_len=196)w_q = torch.normal(0, 0.01, size=(input_dim, output_dim))w_k = torch.normal(0, 0.01, size=(input_dim, output_dim))w_v = torch.normal(0, 0.01, size=(input_dim, output_dim))attention_scores, output = self_attention(patch_embeddings_pe, w_q, w_k, w_v)# plot_heatmap(attention_scores[0])# for index in range(4) :##     name = image_paths[index].split('/')[-1].split('.')[0]#     plot_heatmap(attention_scores[index],index,name)  # 选择第一张图像的注意力分数进行绘制# 将热力图叠加到原始图像上for index in range(4) :image_path = image_paths[index]img = Image.open(image_path).convert('RGB')img_tensor = transform(img)img_np = np.array(img)plot_heatmap_on_image(img_np, attention_scores[index],16,index=index)

参考

  • LLM(廿四):Transformer 的结构改进与替代方案 - 知乎

  • 【深度学习系列】五、Self Attention_self attention 加入位置信息-CSDN博客

  • NLP(五):Transformer及其attention机制 - 知乎

  • 有关vision transformer的一个综述 - 知乎

  • 为什么 Vision transformer 训练和推理很慢? - 知乎

  • 大规模语言模型:从理论到实践 -- 张奇、桂韬、郑锐、黄萱菁

这篇关于Vit Transformer的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/851577

相关文章

ViT:4 Pruning

实时了解业内动态,论文是最好的桥梁,专栏精选论文重点解读热点论文,围绕着行业实践和工程量产。若在某个环节出现卡点,可以回到大模型必备腔调或者LLM背后的基础模型重新阅读。而最新科技(Mamba,xLSTM,KAN)则提供了大模型领域最新技术跟踪。若对于具身智能感兴趣的请移步具身智能专栏。技术宅麻烦死磕AI架构设计。 视觉转换器(ViT)架构已经广受欢迎,并广泛用于计算机视觉应用。然而,随着 Vi

2025秋招NLP算法面试真题(二)-史上最全Transformer面试题:灵魂20问帮你彻底搞定Transformer

简单介绍 之前的20个问题的文章在这里: https://zhuanlan.zhihu.com/p/148656446 其实这20个问题不是让大家背答案,而是为了帮助大家梳理 transformer的相关知识点,所以你注意看会发现我的问题也是有某种顺序的。 本文涉及到的代码可以在这里找到: https://github.com/DA-southampton/NLP_ability 问题

算是一些Transformer学习当中的重点内容

一、基础概念         Transformer是一种神经网络结构,由Vaswani等人在2017年的论文Attentions All YouNeed”中提出,用于处理机器翻译、语言建模和文本生成等自然语言处理任务。Transformer同样是encoder-decoder的结构,只不过这里的“encoder”和“decoder”是由无数个同样结构的encoder层和decoder层堆叠组成

Simple-STNDT使用Transformer进行Spike信号的表征学习(一)数据处理篇

文章目录 1.数据处理部分1.1 下载数据集1.2 数据集预处理1.3 划分train-val并创建Dataset对象1.4 掩码mask操作 数据、评估标准见NLB2021 https://neurallatents.github.io/ 以下代码依据 https://github.com/trungle93/STNDT 原代码使用了 Ray+Config文件进行了参数搜

QuantML-Qlib Model | Kansformer: KAN+Transformer时序模型用于股票收益率预测

QuantML-Qlib Model | Kansformer: KAN+Transformer时序模型用于股票收益率预测 原创 QuantML QuantML 2024-06-18 20:57 上海 Content 之前公众号介绍了几篇KAN的文章,也做过KAN相关的模型: What KAN I say?KAN代码全解析 QuantML-Qlib开发版 | 最新神经网络结构KAN用于因

从零对Transformer的理解(台大李宏毅)

Self-attention layer自注意力 对比与传统cnn和rnn,都是需要t-1时刻的状态然后得到t时刻的状态。我不知道这样理解对不对,反正从代码上看我是这么认为的。而transformer的子注意力机制是在同一时刻产生。意思就是输入一个时间序列,在计算完权重后就直接得到状态。 计算注意力机制的参数q,k,v 那么相信大家肯定看不懂上面的,接下来我来说说a(1,n)怎么计算(如图)

深入解析Transformer:大模型核心技术揭秘

在大模型发展历程中,有两个比较重要点:第一,Transformer 架构。它是模型的底座,但 Transformer 不等于大模型,但大模型的架构可以基于 Transformer;第二,GPT。严格意义上讲,GPT 可能不算是一个模型,更像是一种预训练范式,它本身模型架构是基于 Transformer,但 GPT 引入了“预测下一个词”的任务,即不断通过前文内容预测下一个词。之后,在大量的数据上进

为什么Transformer的编码器和解码器如此重要?一文带你读懂

Transformer 模型是一种基于自注意力(self-attention)机制的深度学习模型,最初是为了解决自然语言处理(NLP)中的序列到序列(sequence-to-sequence)任务而提出的,如机器翻译。Transformer 由编码器和解码器两部分组成,两者都基于相同的自注意力机制,但它们在功能和使用上有所不同。 编码器(Encoder) 编码器是Transformer模型中的

解锁AI的终极秘密:Transformer如何颠覆自然语言处理?

大语言模型系列:Transformer 一、前言二、Transformer诞生的背景三、Transformer横空出世四、Transformer的基本结构4.1. 编码器-解码器架构4.2. 自注意力机制4.3. 多头注意力机制4.4. 前馈神经网络4.5. 残差连接和层归一化 五、Transformer的工作原理5.1输入嵌入5.2位置编码5.3编码器堆栈5.4解码器堆栈5.5输出生成5.

基于Pytorch框架的深度学习Vision Transformer神经网络蝴蝶分类识别系统源码

第一步:准备数据 6种蝴蝶数据:self.class_indict = ["曙凤蝶", "麝凤蝶", "多姿麝凤蝶", "旖凤蝶", "红珠凤蝶", "热斑凤蝶"],总共有900张图片,每个文件夹单独放一种数据 第二步:搭建模型 本文选择一个Vision Transformer网络,其原理介绍如下: Vision Transformer(ViT)是一种基于Transformer架