通俗易懂的Spatial Transformer Networks(STN)(一)

2023-12-22 15:58

本文主要是介绍通俗易懂的Spatial Transformer Networks(STN)(一),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

导读

pytorch为了方便实现STN,里面封装了affine_gridgrid_sample两个高级API。对STN不太了解的同学可以参考这篇详细解读Spatial Transformer Networks(STN)

其实STN的作用是想让CNN具备平移、旋转、缩放、剪切不变性,虽然说CNN中的Pooling可以让网络具备一点平移不变性,但这毕竟是隐性的,如果能让网络直接具备这样的能力岂不是更好。

如果对图像处理有了解的同学也许听过仿射变换这个名词,我们只需要通过变换矩阵 θ \theta θ(由6个参数组成)就能实现上面的这些功能,如果对仿射变换不了解的同学可以参考我的这篇一文搞懂仿射变换

STN也是因为受到这个启发而诞生的,那么我们如何将这种能力嵌入到CNN中呢?这便是STN需要解决的问题

STN简介

在这里插入图片描述

上面引用的文章中已经详细介绍了STN网络,我这里总结概括一下

  • Localisation net

Localisation net模块通过CNN提取图像的特征来预测变换矩阵 θ \theta θ

  • Grid generator

Grid generator模块就是利用Localisation net模块回归出来的 θ \theta θ参数来对图片中的位置进行变换,输入图片到输出图片之间的变换,需要特别注意的是这里指的是图片像素所对应的位置

例如:如果此时 θ \theta θ参数功能是实现图片的平移变换(向右平移1,),输入图片上的坐标(1,1),那对应输出图片上的坐标的(2,1),也就是说输入图片上(1,1)对应的像素值等于输出图片上(2,1)对应的像素值。在变换的时候必然会遇到当输入图片的位置变换到输出图片上是如果位置出现小数怎么办?

  • Sampler

Sampler就是用来解决Grid generator模块变换出现小数位置的问题的。针对这种情况,STN采用的是双线性插值(Bilinear Interpolation),下面我们来介绍一下这个算法
在这里插入图片描述
上图中 ( x , y ) (x,y) (x,y)是变换后输出图像上的位置,带下标的坐标位置表示的是与 ( x , y ) (x,y) (x,y)在输入图像对应的四个相邻的坐标。上面的坐标满足下面的关系
x 1 − x 0 = 1 y 1 − y 0 = 1 x_1-x_0 = 1\\ y1-y_0 = 1 x1x0=1y1y0=1
根据双线性插值的原则距离相邻点近的坐标占的比重越大,所以 ( x , y ) (x,y) (x,y)对应的像素值为,我们用 f ( x , y ) f(x,y) f(x,y)表示点 ( x , y ) (x,y) (x,y)所对应的像素值
f ( x , y ) = ( x 1 − x ) ( y 1 − y ) f ( x 0 , y 0 ) + ( x − x 0 ) ( y 1 − y ) f ( x 1 , y 0 ) = + ( x − x 0 ) ( y − y 0 ) f ( x 1 , y 1 ) + ( x 1 − x ) ( y − y 0 ) f ( x 0 , y 1 ) \begin{aligned} f(x,y) &= (x_1-x)(y1-y)f(x_0,y_0)+(x-x_0)(y_1-y)f(x_1,y_0)\\ &=+(x-x_0)(y-y_0)f(x_1,y_1)+(x_1-x)(y-y_0)f(x_0,y_1) \end{aligned} f(x,y)=(x1x)(y1y)f(x0,y0)+(xx0)(y1y)f(x1,y0)=+(xx0)(yy0)f(x1,y1)+(x1x)(yy0)f(x0,y1)

STN层的实现

  • pytorch的实现

通过pytorchaffine_gridgrid_sample可以很容易实现STN的后两个模块

from torchvision import transforms
import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt#读取图片
img = Image.open("img/test.jpg")
#将图片转换为torch tensor
img_tensor = transforms.ToTensor()(img)#定义平移变换矩阵
#0.1表示将图片向左平移图片宽的百分比
#0.2表示将图片向上平移图片高的百分比
theta = torch.tensor([[1,0,0.1],[0,1,0.2]],dtype=torch.float)
#根据变换矩阵来计算变换后图片的对应位置
grid = F.affine_grid(theta.unsqueeze(0),img_tensor.unsqueeze(0).size(),align_corners=True)
#默认使用双向性插值,可以通过mode参数设置
output = F.grid_sample(img_tensor.unsqueeze(0),grid,align_corners=True)plt.figure()
plt.subplot(1,2,1)
plt.imshow(np.array(img))
plt.title("original image")plt.subplot(1,2,2)
plt.imshow(output[0].numpy().transpose(1,2,0))
plt.title("stn transform image")plt.show()

在这里插入图片描述

  • numpy的实现

我们通过numpy来实现STN的后两个模块,来帮助大家更好的理解STN

class Grid_sample(object):def affine_grid(self,theta,img_size):if len(img_size) != 2:assert("img_size size must is 2")num_batch = np.shape(theta)[0]img_w,img_h = img_size#将图片位置归一化到(-1,1)x = np.linspace(-1.0,1.0,img_w)y = np.linspace(-1.0,1.0,img_h)#组合x和y获取到图片的位置坐标x_t,y_t = np.meshgrid(x,y)x_t_flat = np.reshape(x_t,[-1])y_t_flat = np.reshape(y_t,[-1])#创建一个图片的位置数组ones = np.ones_like(x_t_flat)sampling_grid = np.stack([x_t_flat,y_t_flat,ones])sampling_grid = np.expand_dims(sampling_grid,axis=0)sampling_grid = np.tile(sampling_grid,np.stack([num_batch,1,1]))#计算变换后的图片位置batch_grids = np.matmul(theta,sampling_grid)batch_grids = np.reshape(batch_grids,[num_batch,2,img_h,img_w])return batch_gridsdef bilinear_sampler(self,img,batch_grids):if (batch_grids.shape) != 4:assert("batch_grids shape is must equal 4")#获取变换后图片位置的x和y轴的坐标位置x = batch_grids[:, 0, :, :]y = batch_grids[:, 1, :, :]img_w,img_h = img.shape[:2]max_x = img_w - 1max_y = img_h - 1#将变换后的坐标位置固定到(0,w/h-1)x = 0.5 * ((x+1.0)*(max_x-1))y = 0.5 * ((y+1.0)*(max_y-1))#将坐标位置取整,便于从输入图片中获取位置对应的像素值x0 = np.floor(x).astype(np.int)x1 = x0 + 1y0 = np.floor(y).astype(np.int)y1 = y0 + 1#防止坐标越界x0 = np.clip(x0,0,max_x)x1 = np.clip(x1,0,max_x)y0 = np.clip(y0,0,max_y)y1 = np.clip(y1,0,max_y)#根据坐标位置,取像素值Ia = img[y0,x0,:]Ib = img[y1,x0,:]Ic = img[y0,x1,:]Id = img[y1,x1,:]wa = np.expand_dims((x1-x)*(y1-y),axis=3)wb = np.expand_dims((x1-x)*(y-y0),axis=3)wc = np.expand_dims((x-x0)*(y1-y),axis=3)wd = np.expand_dims((x-x0)*(y-y0),axis=3)#利用双线性插值计算变换后的像素值out = wa*Ia + wb*Ib + wc*Ic + wd*Idreturn outgrid_sampler = Grid_sample()
img = np.array(Image.open("img/test.jpg"))
img_h,img_w = img.shape[:2]
theta = np.array([[[1, 0, 0.1], [0, 1, 0.2]]],dtype=np.float)
theta = np.expand_dims(theta,axis=0)batch_grids = grid_sampler.affine_grid(theta,(img_w,img_h))
out = grid_sampler.bilinear_sampler(img,batch_grids)plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(np.array(img))
plt.title("original image")plt.subplot(1, 2, 2)
plt.imshow(out[0].astype(np.uint8))
plt.title("stn transform image")plt.show()

在这里插入图片描述
下一篇文章我们介绍如何将STN模块插入到CNN中

这篇关于通俗易懂的Spatial Transformer Networks(STN)(一)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/524562

相关文章

设计模式之工厂模式(通俗易懂--代码辅助理解【Java版】)

文章目录 1、工厂模式概述1)特点:2)主要角色:3)工作流程:4)优点5)缺点6)适用场景 2、简单工厂模式(静态工厂模式)1) 在简单工厂模式中,有三个主要角色:2) 简单工厂模式的优点包括:3) 简单工厂模式也有一些限制和考虑因素:4) 简单工厂模式适用场景:5) 简单工厂UML类图:6) 代码示例: 3、工厂方法模式1) 在工厂方法模式中,有4个主要角色:2) 工厂方法模式的工作流程

Transformer从零详细解读

Transformer从零详细解读 一、从全局角度概况Transformer ​ 我们把TRM想象为一个黑盒,我们的任务是一个翻译任务,那么我们的输入是中文的“我爱你”,输入经过TRM得到的结果为英文的“I LOVE YOU” ​ 接下来我们对TRM进行细化,我们将TRM分为两个部分,分别为Encoders(编码器)和Decoders(解码器) ​ 在此基础上我们再进一步细化TRM的

A Comprehensive Survey on Graph Neural Networks笔记

一、摘要-Abstract 1、传统的深度学习模型主要处理欧几里得数据(如图像、文本),而图神经网络的出现和发展是为了有效处理和学习非欧几里得域(即图结构数据)的信息。 2、将GNN划分为四类:recurrent GNNs(RecGNN), convolutional GNNs,(GCN), graph autoencoders(GAE), and spatial–temporal GNNs(S

LLM模型:代码讲解Transformer运行原理

视频讲解、获取源码:LLM模型:代码讲解Transformer运行原理(1)_哔哩哔哩_bilibili 1 训练保存模型文件 2 模型推理 3 推理代码 import torchimport tiktokenfrom wutenglan_model import WutenglanModelimport pyttsx3# 设置设备为CUDA(如果可用),否则使用CPU#

逐行讲解Transformer的代码实现和原理讲解:计算交叉熵损失

LLM模型:Transformer代码实现和原理讲解:前馈神经网络_哔哩哔哩_bilibili 1 计算交叉熵目的 计算 loss = F.cross_entropy(input=linear_predictions_reshaped, target=targets_reshaped) 的目的是为了评估模型预测结果与实际标签之间的差距,并提供一个量化指标,用于指导模型的训练过程。具体来说,交叉

深度学习每周学习总结N9:transformer复现

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制 目录 多头注意力机制前馈传播位置编码编码层解码层Transformer模型构建使用示例 本文为TR3学习打卡,为了保证记录顺序我这里写为N9 总结: 之前有学习过文本预处理的环节,对文本处理的主要方式有以下三种: 1:词袋模型(one-hot编码) 2:TF-I

RNN发展(RNN/LSTM/GRU/GNMT/transformer/RWKV)

RNN到GRU参考: https://blog.csdn.net/weixin_36378508/article/details/115101779 tRANSFORMERS参考: seq2seq到attention到transformer理解 GNMT 2016年9月 谷歌,基于神经网络的翻译系统(GNMT),并宣称GNMT在多个主要语言对的翻译中将翻译误差降低了55%-85%以上, G

ModuleNotFoundError: No module named ‘diffusers.models.dual_transformer_2d‘解决方法

Python应用运行报错,部分错误信息如下: Traceback (most recent call last): File “\pipelines_ootd\unet_vton_2d_blocks.py”, line 29, in from diffusers.models.dual_transformer_2d import DualTransformer2DModel ModuleNotF

Complex Networks Package for MatLab

http://www.levmuchnik.net/Content/Networks/ComplexNetworksPackage.html 翻译: 复杂网络的MATLAB工具包提供了一个高效、可扩展的框架,用于在MATLAB上的网络研究。 可以帮助描述经验网络的成千上万的节点,生成人工网络,运行鲁棒性实验,测试网络在不同的攻击下的可靠性,模拟任意复杂的传染病的传

Convolutional Neural Networks for Sentence Classification论文解读

基本信息 作者Yoon Kimdoi发表时间2014期刊EMNLP网址https://doi.org/10.48550/arXiv.1408.5882 研究背景 1. What’s known 既往研究已证实 CV领域著名的CNN。 2. What’s new 创新点 将CNN应用于NLP,打破了传统NLP任务主要依赖循环神经网络(RNN)及其变体的局面。 用预训练的词向量(如word2v