【深度学习】解析Vision Transformer (ViT): 从基础到实现与训练

2024-06-16 12:28

本文主要是介绍【深度学习】解析Vision Transformer (ViT): 从基础到实现与训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

之前介绍:

https://qq742971636.blog.csdn.net/article/details/132061304

文章目录

  • 背景
      • 实现代码示例
      • 解释
  • 训练
      • 数据准备
      • 模型定义
      • 训练和评估
      • 总结

在这里插入图片描述

Vision Transformer(ViT)是一种基于transformer架构的视觉模型,它最初是由谷歌研究团队在论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》中提出的。ViT将图像分割成固定大小的patches(例如16x16),并将每个patch视为一个词(类似于NLP中的单词)进行处理。以下是ViT的详细讲解:

背景

在计算机视觉领域,传统的卷积神经网络(CNNs)一直是处理图像的主流方法。然而,CNNs存在一些局限性,如在处理长距离依赖关系时表现不佳。ViT引入了transformer架构,通过全局注意力机制,有效地处理图像中的长距离依赖关系。

实现代码示例

ViT代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeatclass PatchEmbedding(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):super().__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.grid_size = img_size // patch_sizeself.num_patches = self.grid_size ** 2self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):x = self.proj(x)  # [B, embed_dim, H, W]x = x.flatten(2)  # [B, embed_dim, num_patches]x = x.transpose(1, 2)  # [B, num_patches, embed_dim]return xclass Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):super().__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)def forward(self, x):B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return xclass MLP(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = nn.GELU()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.):super().__init__()self.norm1 = nn.LayerNorm(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)self.drop_path = nn.Identity() if drop_path == 0 else nn.Dropout(drop_path)self.norm2 = nn.LayerNorm(dim)self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):super().__init__()self.num_classes = num_classesself.num_features = self.embed_dim = embed_dimself.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)num_patches = self.patch_embed.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))self.pos_drop = nn.Dropout(p=drop_rate)dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i])for i in range(depth)])self.norm = nn.LayerNorm(embed_dim)self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()nn.init.trunc_normal_(self.pos_embed, std=0.02)nn.init.trunc_normal_(self.cls_token, std=0.02)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):nn.init.trunc_normal_(m.weight, std=0.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, nn.LayerNorm):nn.init.ones_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):B = x.shape[0]x = self.patch_embed(x)cls_tokens = self.cls_token.expand(B, -1, -1)x = torch.cat((cls_tokens, x), dim=1)x = x + self.pos_embedx = self.pos_drop(x)for blk in self.blocks:x = blk(x)x = self.norm(x)cls_token_final = x[:, 0]x = self.head(cls_token_final)return x# 示例输入
img = torch.randn(1, 3, 224, 224)
model = VisionTransformer()
output = model(img)
print(output.shape)  # 输出大小为 [1, 1000]

解释

  1. PatchEmbedding:将输入图像分割为不重叠的patches,并通过卷积操作将其转换为embedding。
  2. Attention:实现自注意力机制。
  3. MLP:实现多层感知器(MLP),包括GELU激活函数和Dropout。
  4. Block:包含一个注意力层和一个MLP层,每层都有残差连接和层归一化。
  5. VisionTransformer:组合上述模块,形成完整的ViT模型。包含位置嵌入和分类头。

训练

为了在GPU上训练ViT模型,你可以使用PyTorch中的DataLoader来处理数据,并确保模型和数据都在GPU上。以下是一个详细的代码示例,包括数据准备、模型定义、训练和评估。

数据准备

假设你的数据结构如下:

dataset/class1/img1.jpgimg2.jpg...class2/img1.jpgimg2.jpg......

你可以使用 torchvision.datasets.ImageFolder 来加载数据。

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm# 数据转换和增强
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载数据
data_dir = 'dataset'
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)# 获取类别数
num_classes = len(train_dataset.classes)

模型定义

定义ViT模型并将其移动到GPU上。

# VisionTransformer定义(使用上面的定义)
model = VisionTransformer(num_classes=num_classes).cuda()# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)# 如果有多个GPU,使用DataParallel
if torch.cuda.device_count() > 1:model = nn.DataParallel(model)

训练和评估

定义训练和评估函数,并进行训练。

def train_one_epoch(model, criterion, optimizer, data_loader, device):model.train()running_loss = 0.0running_corrects = 0for inputs, labels in tqdm(data_loader):inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(data_loader.dataset)epoch_acc = running_corrects.double() / len(data_loader.dataset)return epoch_loss, epoch_accdef evaluate(model, criterion, data_loader, device):model.eval()running_loss = 0.0running_corrects = 0with torch.no_grad():for inputs, labels in data_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(data_loader.dataset)epoch_acc = running_corrects.double() / len(data_loader.dataset)return epoch_loss, epoch_acc# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_epochs = 25for epoch in range(num_epochs):train_loss, train_acc = train_one_epoch(model, criterion, optimizer, train_loader, device)val_loss, val_acc = evaluate(model, criterion, val_loader, device)print(f'Epoch {epoch}/{num_epochs - 1}')print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')# 保存模型
torch.save(model.state_dict(), 'vit_model.pth')

总结

这段代码展示了如何使用PyTorch在GPU上训练Vision Transformer模型。包括数据加载、模型定义、训练和评估步骤。请根据你的实际需求调整批量大小、学习率和训练轮数等参数。

这篇关于【深度学习】解析Vision Transformer (ViT): 从基础到实现与训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1066453

相关文章

MySQL索引的优化之LIKE模糊查询功能实现

《MySQL索引的优化之LIKE模糊查询功能实现》:本文主要介绍MySQL索引的优化之LIKE模糊查询功能实现,本文通过示例代码给大家介绍的非常详细,感兴趣的朋友一起看看吧... 目录一、前缀匹配优化二、后缀匹配优化三、中间匹配优化四、覆盖索引优化五、减少查询范围六、避免通配符开头七、使用外部搜索引擎八、分

Python实现特殊字符判断并去掉非字母和数字的特殊字符

《Python实现特殊字符判断并去掉非字母和数字的特殊字符》在Python中,可以通过多种方法来判断字符串中是否包含非字母、数字的特殊字符,并将这些特殊字符去掉,本文为大家整理了一些常用的,希望对大家... 目录1. 使用正则表达式判断字符串中是否包含特殊字符去掉字符串中的特殊字符2. 使用 str.isa

Spring Boot 集成 Quartz并使用Cron 表达式实现定时任务

《SpringBoot集成Quartz并使用Cron表达式实现定时任务》本篇文章介绍了如何在SpringBoot中集成Quartz进行定时任务调度,并通过Cron表达式控制任务... 目录前言1. 添加 Quartz 依赖2. 创建 Quartz 任务3. 配置 Quartz 任务调度4. 启动 Sprin

Android实现悬浮按钮功能

《Android实现悬浮按钮功能》在很多场景中,我们希望在应用或系统任意界面上都能看到一个小的“悬浮按钮”(FloatingButton),用来快速启动工具、展示未读信息或快捷操作,所以本文给大家介绍... 目录一、项目概述二、相关技术知识三、实现思路四、整合代码4.1 Java 代码(MainActivi

使用Python实现一个优雅的异步定时器

《使用Python实现一个优雅的异步定时器》在Python中实现定时器功能是一个常见需求,尤其是在需要周期性执行任务的场景下,本文给大家介绍了基于asyncio和threading模块,可扩展的异步定... 目录需求背景代码1. 单例事件循环的实现2. 事件循环的运行与关闭3. 定时器核心逻辑4. 启动与停

基于Python实现读取嵌套压缩包下文件的方法

《基于Python实现读取嵌套压缩包下文件的方法》工作中遇到的问题,需要用Python实现嵌套压缩包下文件读取,本文给大家介绍了详细的解决方法,并有相关的代码示例供大家参考,需要的朋友可以参考下... 目录思路完整代码代码优化思路打开外层zip压缩包并遍历文件:使用with zipfile.ZipFil

Python实现word文档内容智能提取以及合成

《Python实现word文档内容智能提取以及合成》这篇文章主要为大家详细介绍了如何使用Python实现从10个左右的docx文档中抽取内容,再调整语言风格后生成新的文档,感兴趣的小伙伴可以了解一下... 目录核心思路技术路径实现步骤阶段一:准备工作阶段二:内容提取 (python 脚本)阶段三:语言风格调

Java 正则表达式URL 匹配与源码全解析

《Java正则表达式URL匹配与源码全解析》在Web应用开发中,我们经常需要对URL进行格式验证,今天我们结合Java的Pattern和Matcher类,深入理解正则表达式在实际应用中... 目录1.正则表达式分解:2. 添加域名匹配 (2)3. 添加路径和查询参数匹配 (3) 4. 最终优化版本5.设计思

C#实现将Excel表格转换为图片(JPG/ PNG)

《C#实现将Excel表格转换为图片(JPG/PNG)》Excel表格可能会因为不同设备或字体缺失等问题,导致格式错乱或数据显示异常,转换为图片后,能确保数据的排版等保持一致,下面我们看看如何使用C... 目录通过C# 转换Excel工作表到图片通过C# 转换指定单元格区域到图片知识扩展C# 将 Excel

Android Mainline基础简介

《AndroidMainline基础简介》AndroidMainline是通过模块化更新Android核心组件的框架,可能提高安全性,本文给大家介绍AndroidMainline基础简介,感兴趣的朋... 目录关键要点什么是 android Mainline?Android Mainline 的工作原理关键