通俗易懂的Spatial Transformer Networks(STN)(二)

2023-12-22 15:58

本文主要是介绍通俗易懂的Spatial Transformer Networks(STN)(二),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

导读

上一篇通俗易懂的Spatial Transformer Networks(STN)(一)中,我们详细介绍了STN中会使用到的几个模块,并且用pytorchnumpy来实现了,这篇文章我们将会利用pytorch来实现一个MNIST的手写数字识别并且将STN模块嵌入到CNN中

STN关键点解读

STN有一个最大的特点就是STN模块能够很容易的嵌入到CNN中,只需要进行非常小的修改即可。上一篇文章我们也说了STN拥有平移、旋转、剪切、缩放等不变性,而这一特点主要是依赖 θ \theta θ参数来实现的。刚开始的时候我还以为训练STN还需要准备 θ \theta θ标签数据,实际上并不需要。

当输入图片通过STN模块之后获得变换后的图片,然后我们再将变换后的图片输入到CNN网络中,通过损失函数计算loss,然后计算梯度更新 θ \theta θ参数,最终STN模块会学习到如何矫正图片。

代码实现

  • 导包
import torch,torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
import numpy as np
from torchsummary import summary
import argparse
  • 定义网络结构
class STN_Net(nn.Module):def __init__(self,use_stn=True):super(STN_Net, self).__init__()self.conv1 = nn.Conv2d(1,10,kernel_size=5)self.conv2 = nn.Conv2d(10,20,kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320,50)self.fc2 = nn.Linear(50,10)#用来判断是否使用STNself._use_stn = use_stn#localisation net#从输入图像中提取特征#输入图片的shape为(-1,1,28,28)self.localization = nn.Sequential(#卷积输出shape为(-1,8,22,22)nn.Conv2d(1,8,kernel_size=7),#最大池化输出shape为(-1,1,11,11)nn.MaxPool2d(2,stride=2),nn.ReLU(True),#卷积输出shape为(-1,10,7,7)nn.Conv2d(8,10,kernel_size=5),#最大池化层输出shape为(-1,10,3,3)nn.MaxPool2d(2,stride=2),nn.ReLU(True))#利用全连接层回归\theta参数self.fc_loc = nn.Sequential(nn.Linear(10 * 3 * 3,32),nn.ReLU(True),nn.Linear(32,2*3))self.fc_loc[2].weight.data.zero_()self.fc_loc[2].bias.data.copy_(torch.tensor([1,0,0,0,1,0],dtype=torch.float))def stn(self,x):#提取输入图像中的特征xs = self.localization(x)xs = xs.view(-1,10*3*3)#回归theta参数theta = self.fc_loc(xs)theta = theta.view(-1,2,3)#利用theta参数计算变换后图片的位置grid = F.affine_grid(theta,x.size())#根据输入图片计算变换后图片位置填充的像素值x = F.grid_sample(x,grid)return xdef forward(self,x):#使用STN模块if self._use_stn:x = self.stn(x)#利用STN矫正过的图片来进行图片的分类#经过conv1卷积输出的shape为(-1,10,24,24)#经过max pool的输出shape为(-1,10,12,12)x = F.relu(F.max_pool2d(self.conv1(x),2))#经过conv2卷积输出的shape为(-1,20,8,8)#经过max pool的输出shape为(-1,20,4,4)x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))x = x.view(-1,320)x = F.relu(self.fc1(x))x = F.dropout(x,training=self.training)x = self.fc2(x)return F.log_softmax(x,dim=1)
  • 加载数据集
def get_dataloader(batch_size):# 加载数据集# 如果GPU可用就用GPU,否则用CPUdevice = torch.device("cuda" if torch.cuda.is_available()else "cpu")# 加载训练集train_dataloader = torch.utils.data.DataLoader(datasets.MNIST(root="D:/dataset", train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])), batch_size=batch_size, shuffle=True)# 加载测试集test_dataloader = torch.utils.data.DataLoader(datasets.MNIST(root="D:/dataset", train=False,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])), batch_size=batch_size, shuffle=True)return train_dataloader,test_dataloader
  • 训练模型
def train(net,epoch_nums,lr,train_dataloader,per_batch,device):#使用训练模式net.train()#选择梯度下降优化算法optimizer = optim.SGD(net.parameters(),lr=lr)#训练模型for epoch in range(epoch_nums):for batch_idx,(data,label) in enumerate(train_dataloader):data,label = data.to(device),label.to(device)optimizer.zero_grad()pred = net(data)loss = F.nll_loss(pred,label)loss.backward()optimizer.step()if batch_idx % per_batch == 0:print("Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss:{:.6f}".format(epoch,batch_idx * len(data),len(train_dataloader.dataset),100. * batch_idx /len(train_dataloader),loss.item()))
  • 评估模型
def evaluate(net,test_dataloader,device):with torch.no_grad():#使用评估模式net.eval()eval_loss = 0eval_acc = 0for data,label in test_dataloader:data,label = data.to(device),label.to(device)pred = net(data)eval_loss += F.nll_loss(pred,label,size_average=False).item()pred_label = pred.max(1,keepdim=True)[1]eval_acc += pred_label.eq(label.view_as(pred_label)).sum().item()eval_loss /= len(test_dataloader.dataset)print("evaluate set: Average loss: {:.4f},Accuracy:{}/{} ({:.2f}%)\n".format(eval_loss,eval_acc,len(test_dataloader.dataset),100*eval_acc / len(test_dataloader.dataset)))
  • 将pytorch的tensor转换为numpy的array
def tensor_to_array(img_tensor):img_array = img_tensor.numpy().transpose((1,2,0))mean = np.array([0.485,0.456,0.406])std = np.array([0.229,0.224,0.225])img_array = std * img_array + meanimg = np.clip(img_array,0,1)return img
  • 可视化STN变换图片
def visualize_stn(net,dataloader,device):with torch.no_grad():data = next(iter(dataloader))[0].to(device)input_tensor = data.cpu()t_input_tensor = net.stn(data).cpu()in_grid = tensor_to_array(torchvision.utils.make_grid(input_tensor))out_grid = tensor_to_array(torchvision.utils.make_grid(t_input_tensor))f,axarr = plt.subplots(1,2)axarr[0].imshow(in_grid)axarr[0].set_title("input images")axarr[1].imshow(out_grid)axarr[1].set_title("stn transformed images")plt.show()

在这里插入图片描述
通过对比输入图片和经过STN变换后的图片能够很明显发现,经过STN之后能将旋转的图片进行明显的纠正。

  • 参数设置
def parse_args():parse = argparse.ArgumentParser("config stn args")parse.add_argument("--lr",default=0.01,type=float,help="learning rate")parse.add_argument("--epoch_nums",default=20,type=int,help="iterated epochs")parse.add_argument("--use_stn",default=True,type=bool,help="whether to use STN module")parse.add_argument("--batch_size",default=64,type=int,help="batch size")parse.add_argument("--use_eval",default=True,type=bool,help="whether to evaluate")parse.add_argument("--use_visual",default=True,type=bool,help="visual STN transform image")parse.add_argument("--use_gpu",default=True,type=bool,help="whether to use GPU")parse.add_argument("--show_net_construct",default=False,type=bool,help="print net construct info")return parse.parse_args()
  • 主函数
if __name__ == "__main__":args = parse_args()if args.use_gpu and torch.cuda.is_available():device = "cuda"else:device = "cpu"#加载数据集train_loader,test_loader = get_dataloader(args.batch_size)#创建网络net = STN_Net(args.use_stn).to(device)#打印网络的结构信息if args.show_net_construct:summary(net,(1,28,28))#训练模型train(net,args.epoch_nums,args.lr,train_loader,args.batch_size,device)if args.use_eval:#评估模型evaluate(net,test_loader,device)if args.use_visual:#可视化展示效果visualize_stn(net,test_loader,device)

参考:https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

这篇关于通俗易懂的Spatial Transformer Networks(STN)(二)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/524564

相关文章

设计模式之工厂模式(通俗易懂--代码辅助理解【Java版】)

文章目录 1、工厂模式概述1)特点:2)主要角色:3)工作流程:4)优点5)缺点6)适用场景 2、简单工厂模式(静态工厂模式)1) 在简单工厂模式中,有三个主要角色:2) 简单工厂模式的优点包括:3) 简单工厂模式也有一些限制和考虑因素:4) 简单工厂模式适用场景:5) 简单工厂UML类图:6) 代码示例: 3、工厂方法模式1) 在工厂方法模式中,有4个主要角色:2) 工厂方法模式的工作流程

Transformer从零详细解读

Transformer从零详细解读 一、从全局角度概况Transformer ​ 我们把TRM想象为一个黑盒,我们的任务是一个翻译任务,那么我们的输入是中文的“我爱你”,输入经过TRM得到的结果为英文的“I LOVE YOU” ​ 接下来我们对TRM进行细化,我们将TRM分为两个部分,分别为Encoders(编码器)和Decoders(解码器) ​ 在此基础上我们再进一步细化TRM的

A Comprehensive Survey on Graph Neural Networks笔记

一、摘要-Abstract 1、传统的深度学习模型主要处理欧几里得数据(如图像、文本),而图神经网络的出现和发展是为了有效处理和学习非欧几里得域(即图结构数据)的信息。 2、将GNN划分为四类:recurrent GNNs(RecGNN), convolutional GNNs,(GCN), graph autoencoders(GAE), and spatial–temporal GNNs(S

LLM模型:代码讲解Transformer运行原理

视频讲解、获取源码:LLM模型:代码讲解Transformer运行原理(1)_哔哩哔哩_bilibili 1 训练保存模型文件 2 模型推理 3 推理代码 import torchimport tiktokenfrom wutenglan_model import WutenglanModelimport pyttsx3# 设置设备为CUDA(如果可用),否则使用CPU#

逐行讲解Transformer的代码实现和原理讲解:计算交叉熵损失

LLM模型:Transformer代码实现和原理讲解:前馈神经网络_哔哩哔哩_bilibili 1 计算交叉熵目的 计算 loss = F.cross_entropy(input=linear_predictions_reshaped, target=targets_reshaped) 的目的是为了评估模型预测结果与实际标签之间的差距,并提供一个量化指标,用于指导模型的训练过程。具体来说,交叉

深度学习每周学习总结N9:transformer复现

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制 目录 多头注意力机制前馈传播位置编码编码层解码层Transformer模型构建使用示例 本文为TR3学习打卡,为了保证记录顺序我这里写为N9 总结: 之前有学习过文本预处理的环节,对文本处理的主要方式有以下三种: 1:词袋模型(one-hot编码) 2:TF-I

RNN发展(RNN/LSTM/GRU/GNMT/transformer/RWKV)

RNN到GRU参考: https://blog.csdn.net/weixin_36378508/article/details/115101779 tRANSFORMERS参考: seq2seq到attention到transformer理解 GNMT 2016年9月 谷歌,基于神经网络的翻译系统(GNMT),并宣称GNMT在多个主要语言对的翻译中将翻译误差降低了55%-85%以上, G

ModuleNotFoundError: No module named ‘diffusers.models.dual_transformer_2d‘解决方法

Python应用运行报错,部分错误信息如下: Traceback (most recent call last): File “\pipelines_ootd\unet_vton_2d_blocks.py”, line 29, in from diffusers.models.dual_transformer_2d import DualTransformer2DModel ModuleNotF

Complex Networks Package for MatLab

http://www.levmuchnik.net/Content/Networks/ComplexNetworksPackage.html 翻译: 复杂网络的MATLAB工具包提供了一个高效、可扩展的框架,用于在MATLAB上的网络研究。 可以帮助描述经验网络的成千上万的节点,生成人工网络,运行鲁棒性实验,测试网络在不同的攻击下的可靠性,模拟任意复杂的传染病的传

Convolutional Neural Networks for Sentence Classification论文解读

基本信息 作者Yoon Kimdoi发表时间2014期刊EMNLP网址https://doi.org/10.48550/arXiv.1408.5882 研究背景 1. What’s known 既往研究已证实 CV领域著名的CNN。 2. What’s new 创新点 将CNN应用于NLP,打破了传统NLP任务主要依赖循环神经网络(RNN)及其变体的局面。 用预训练的词向量(如word2v