本文主要是介绍理解Pytorch中的grid_sample函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文章目录
- 函数签名
- 参数说明
- 示意图
grid_sample是 PyTorch 提供的一个函数,用于执行采样操作,通常用于图像处理。它允许通过给定的采样坐标从输入张量中获取相应的值。采样坐标可以包含小数,这时
grid_sample
会使用插值方法计算出对应的值。
torch.nn.functional.grid_sample
是 PyTorch 中用于从输入特征图中采样的函数。它接受一个输入张量(通常是特征图)和一个包含采样点坐标的网格(grid),并在输入张量中按照网格坐标采样,生成一个新的特征图。
函数签名
torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
参数说明
input
:
这是一个形状为(N, C, H_in, W_in)
的 4D 张量,其中N
是批次大小,C
是通道数,H_in
和W_in
分别是输入特征图的高度和宽度。grid
:
这是一个形状为(N, H_out, W_out, 2)
的 4D 张量,表示目标位置的网格。最后一维表示每个位置的(x, y)
坐标,值的范围通常在[-1, 1]
之间,其中-1
对应左/上边界,1
对应右/下边界。mode
:
指定插值方式,有两个选项:'bilinear'
(默认):使用双线性插值。'nearest'
:使用最近邻插值。
padding_mode
:
当采样点超出输入特征图边界时指定填充方式,有三个选项:'zeros'
(默认):超出边界的点填充为 0。'border'
:超出边界的点采用边界值填充。'reflection'
:超出边界的点使用对称填充。
align_corners
(有懂哥可以解释的更清楚一点):True
:采样网格的边缘点直接对齐到原始特征图的像素格上。False
:采样网格的边缘点直接对齐到原始特征图的像素格的角点上。
示意图
这里补充一下,grid经常会生成小数点的值,这些小数点的值是没法作为索引切片的。所以这时候插值的方法就会影响最终的结果了。
grid_sample
提供两种插值方式:
-
mode='bilinear'
(默认):
-
进行双线性插值(bilinear interpolation)。当坐标包含小数时,
grid_sample
会根据周围的像素值来计算出精确的采样结果。这意味着,如果采样点的坐标(即displacement
)落在像素之间,grid_sample
会根据四个相邻像素的值进行加权平均,生成插值结果。 -
具体来说,如果采样点
(x, y)
对应的坐标在(i, j)
和(i+1, j+1)
之间,双线性插值会计算如下:
value = ( 1 − Δ x ) ( 1 − Δ y ) ⋅ V i , j + Δ x ( 1 − Δ y ) ⋅ V i + 1 , j + ( 1 − Δ x ) Δ y ⋅ V i , j + 1 + Δ x Δ y ⋅ V i + 1 , j + 1 \text{value} = (1 - \Delta x)(1 - \Delta y) \cdot V_{i,j} + \Delta x(1 - \Delta y) \cdot V_{i+1,j} + (1 - \Delta x) \Delta y \cdot V_{i,j+1} + \Delta x \Delta y \cdot V_{i+1,j+1} value=(1−Δx)(1−Δy)⋅Vi,j+Δx(1−Δy)⋅Vi+1,j+(1−Δx)Δy⋅Vi,j+1+ΔxΔy⋅Vi+1,j+1
其中, Δ x \Delta x Δx 和 Δ y \Delta y Δy 是坐标的小数部分, V i , j V_{i,j} Vi,j 是像素值。
-
-
mode='nearest'
- 采用最近邻插值(nearest-neighbor interpolation)。如果采样坐标包含小数,
grid_sample
会取最近的整数位置对应的像素值。
- 采用最近邻插值(nearest-neighbor interpolation)。如果采样坐标包含小数,
另外grid的取值范围是 [-1, 1]
,在函数内部会进行尺度的复原:
real ix = THTensor_fastGet4d(grid, n, h, w, 0);
real iy = THTensor_fastGet4d(grid, n, h, w, 1);// normalize ix, iy from [-1, 1] to [0, IH-1] & [0, IW-1]
ix = ((ix + 1) / 2) * (IW-1);
iy = ((iy + 1) / 2) * (IH-1);
这篇关于理解Pytorch中的grid_sample函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!