PyTorch中文教程 | (8) 空间变换器网络教程

2023-11-09 19:00

本文主要是介绍PyTorch中文教程 | (8) 空间变换器网络教程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Github

在本教程中,您将学习如何使用称为空间变换器网络的视觉注意机制来扩充您的网络。你可以在 DeepMind paper阅读有关空间变换器网络的更多内容。

空间变换器网络是对任何空间变换的差异化关注的概括。空间变换器网络(简称STN)允许神经网络学习如何在输入图像上执行空间变换,以增强模型的几何不变性。例如,它可以裁剪感兴趣的区域,缩放并校正图像的方向。它可能是一种有用的机制,因为CNN对于旋转和缩放以及更一般的仿射变换并不是不变的。

关于STN的最棒的事情之一是能够简单地将其插入任何现有的CNN,只需很少的修改。

# License: BSD
# 作者: Ghassen Hamrounifrom __future__ import print_function  
import torch  
import torch.nn as nn  
import torch.nn.functional as F  
import torch.optim as optim  
import torchvision  
from torchvision import datasets, transforms  
import matplotlib.pyplot as plt  
import numpy as np  plt.ion()   # 交互模式

目录

1. 加载数据

2. 空间变换器网络叙述

3. 训练模型

4. 可视化STN结果


1. 加载数据

在这篇文章中,我们尝试了经典的MNIST数据集。使用标准卷积网络增强空间变换器网络

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Training dataset
train_loader = torch.utils.data.DataLoader(datasets.MNIST(root='.', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST(root='.', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])), batch_size=64, shuffle=True, num_workers=4)

 

2. 空间变换器网络叙述

空间变换器网络归结为三个主要组成部分:

  • 本地网络(Localisation Network)是常规CNN,其对变换参数进行回归。不会从该数据集中明确地学习转换,而是网络自动学习增强全局准确性的空间变换。
  • 网格生成器( Grid Genator)在输入图像中生成与输出图像中的每个像素相对应的坐标网格。
  • 采样器(Sampler)使用变换的参数并将其应用于输入图像。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, 10)# Spatial transformer localization-networkself.localization = nn.Sequential(nn.Conv2d(1, 8, kernel_size=7),nn.MaxPool2d(2, stride=2),nn.ReLU(True),nn.Conv2d(8, 10, kernel_size=5),nn.MaxPool2d(2, stride=2),nn.ReLU(True))# Regressor for the 3 * 2 affine matrixself.fc_loc = nn.Sequential(nn.Linear(10 * 3 * 3, 32),nn.ReLU(True),nn.Linear(32, 3 * 2))# Initialize the weights/bias with identity transformationself.fc_loc[2].weight.data.zero_()self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))# Spatial transformer network forward functiondef stn(self, x):xs = self.localization(x)xs = xs.view(-1, 10 * 3 * 3)theta = self.fc_loc(xs)theta = theta.view(-1, 2, 3)grid = F.affine_grid(theta, x.size())x = F.grid_sample(x, grid)return xdef forward(self, x):# transform the inputx = self.stn(x)# Perform the usual forward passx = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)model = Net().to(device)

3. 训练模型

现在我们使用SGD(随机梯度下降)算法来训练模型。网络正在以有监督的方式学习分类任务。同时,该模型以端到端的方式自动学习STN。

optimizer = optim.SGD(model.parameters(), lr=0.01)def train(epoch):model.train() #训练模式for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx % 500 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))
#
# A simple test procedure to measure STN the performances on MNIST.
#def test():with torch.no_grad():model.eval() #评估模式test_loss = 0correct = 0for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)# sum up batch losstest_loss += F.nll_loss(output, target, size_average=False).item()# get the index of the max log-probabilitypred = output.max(1, keepdim=True)[1] #[1]表示取最大值所在的索引correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

 

4. 可视化STN结果

现在,我们将检查我们学习的视觉注意机制的结果。我们定义了一个小辅助函数,以便在训练时可视化变换。

def convert_image_np(inp):"""Convert a Tensor to numpy image."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)return inp# We want to visualize the output of the spatial transformers layer
# after the training, we visualize a batch of input images and
# the corresponding transformed batch using STN.def visualize_stn():with torch.no_grad():# Get a batch of training datadata = next(iter(test_loader))[0].to(device)input_tensor = data.cpu()transformed_input_tensor = model.stn(data).cpu()in_grid = convert_image_np(torchvision.utils.make_grid(input_tensor))out_grid = convert_image_np(torchvision.utils.make_grid(transformed_input_tensor))# Plot the results side-by-sidef, axarr = plt.subplots(1, 2)axarr[0].imshow(in_grid)axarr[0].set_title('Dataset Images')axarr[1].imshow(out_grid)axarr[1].set_title('Transformed Images')for epoch in range(1, 20 + 1):train(epoch)test()# Visualize the STN transformation on some input batch
visualize_stn()plt.ioff()
plt.show()

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

这篇关于PyTorch中文教程 | (8) 空间变换器网络教程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/378020

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

中文分词jieba库的使用与实景应用(一)

知识星球:https://articles.zsxq.com/id_fxvgc803qmr2.html 目录 一.定义: 精确模式(默认模式): 全模式: 搜索引擎模式: paddle 模式(基于深度学习的分词模式): 二 自定义词典 三.文本解析   调整词出现的频率 四. 关键词提取 A. 基于TF-IDF算法的关键词提取 B. 基于TextRank算法的关键词提取

Makefile简明使用教程

文章目录 规则makefile文件的基本语法:加在命令前的特殊符号:.PHONY伪目标: Makefilev1 直观写法v2 加上中间过程v3 伪目标v4 变量 make 选项-f-n-C Make 是一种流行的构建工具,常用于将源代码转换成可执行文件或者其他形式的输出文件(如库文件、文档等)。Make 可以自动化地执行编译、链接等一系列操作。 规则 makefile文件

SWAP作物生长模型安装教程、数据制备、敏感性分析、气候变化影响、R模型敏感性分析与贝叶斯优化、Fortran源代码分析、气候数据降尺度与变化影响分析

查看原文>>>全流程SWAP农业模型数据制备、敏感性分析及气候变化影响实践技术应用 SWAP模型是由荷兰瓦赫宁根大学开发的先进农作物模型,它综合考虑了土壤-水分-大气以及植被间的相互作用;是一种描述作物生长过程的一种机理性作物生长模型。它不但运用Richard方程,使其能够精确的模拟土壤中水分的运动,而且耦合了WOFOST作物模型使作物的生长描述更为科学。 本文让更多的科研人员和农业工作者

vscode中文乱码问题,注释,终端,调试乱码一劳永逸版

忘记咋回事突然出现了乱码问题,很多方法都试了,注释乱码解决了,终端又乱码,调试窗口也乱码,最后经过本人不懈努力,终于全部解决了,现在分享给大家我的方法。 乱码的原因是各个地方用的编码格式不统一,所以把他们设成统一的utf8. 1.电脑的编码格式 开始-设置-时间和语言-语言和区域 管理语言设置-更改系统区域设置-勾选Bata版:使用utf8-确定-然后按指示重启 2.vscode

沁恒CH32在MounRiver Studio上环境配置以及使用详细教程

目录 1.  RISC-V简介 2.  CPU架构现状 3.  MounRiver Studio软件下载 4.  MounRiver Studio软件安装 5.  MounRiver Studio软件介绍 6.  创建工程 7.  编译代码 1.  RISC-V简介         RISC就是精简指令集计算机(Reduced Instruction SetCom

前端技术(七)——less 教程

一、less简介 1. less是什么? less是一种动态样式语言,属于css预处理器的范畴,它扩展了CSS语言,增加了变量、Mixin、函数等特性,使CSS 更易维护和扩展LESS 既可以在 客户端 上运行 ,也可以借助Node.js在服务端运行。 less的中文官网:https://lesscss.cn/ 2. less编译工具 koala 官网 http://koala-app.

【Shiro】Shiro 的学习教程(三)之 SpringBoot 集成 Shiro

目录 1、环境准备2、引入 Shiro3、实现认证、退出3.1、使用死数据实现3.2、引入数据库,添加注册功能后端代码前端代码 3.3、MD5、Salt 的认证流程 4.、实现授权4.1、基于角色授权4.2、基于资源授权 5、引入缓存5.1、EhCache 实现缓存5.2、集成 Redis 实现 Shiro 缓存 1、环境准备 新建一个 SpringBoot 工程,引入依赖:

解决Office Word不能切换中文输入

我们在使用WORD的时可能会经常碰到WORD中无法输入中文的情况。因为,虽然我们安装了搜狗输入法,但是到我们在WORD中使用搜狗的输入法的切换中英文的按键的时候会发现根本没有效果,无法将输入法切换成中文的。下面我就介绍一下如何在WORD中把搜狗输入法切换到中文。

Windows环境利用VS2022编译 libvpx 源码教程

libvpx libvpx 是一个开源的视频编码库,由 WebM 项目开发和维护,专门用于 VP8 和 VP9 视频编码格式的编解码处理。它支持高质量的视频压缩,广泛应用于视频会议、在线教育、视频直播服务等多种场景中。libvpx 的特点包括跨平台兼容性、硬件加速支持以及灵活的接口设计,使其可以轻松集成到各种应用程序中。 libvpx 的安装和配置过程相对简单,用户可以从官方网站下载源代码