本文主要是介绍PyTorch中torch.squeeze() 和torch.unsqueeze()用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
squeeze的用法主要就是对数据的维度进行压缩或者解压
- torch.squeeze() 对数据的维度进行压缩
https://pytorch.org/docs/stable/torch.html?highlight=torch%20squeeze#torch.squeeze
(a) 去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的数去掉第一个维数为一的维度之后就变成(3)行。squeeze(a)就是将a中所有为1的维度删掉。不为1的维度没有影响。a.squeeze(N) 就是去掉a中指定的维数为一的维度。
(b) b=torch.squeeze(a,N) a中去掉指定的定的维数的维度。
- torch.unsqueeze()对数据维度进行扩充
https://pytorch.org/docs/stable/torch.html?highlight=torch%20squeeze#torch.unsqueeze
(a) 给指定位置加上维数为一的维度,比如原本有个三行的数据(3),在0的位置加了一维就变成一行三列(1,3)。a.squeeze(N) 就是在a中指定位置N加上一个维数为1的维度。
(b) b=torch.squeeze(a,N) a就是在a中指定位置N加上一个维数为1的维度
torch.unsqueeze 举例说明
x = torch.tensor([1, 2, 3, 4])
print('x.shape = ',x.shape)
print('x = ',x)
x.shape = torch.Size([4])
x = tensor([1, 2, 3, 4])
x1 = torch.unsqueeze(x, 0) # x 在0的位置加上一个维度
print('x1.shape = ',x1.shape)
print('x1 = ',x1)
x1.shape = torch.Size([1, 4])
x1 = tensor([[1, 2, 3, 4]])
x2 = torch.unsqueeze(x, 1)
print('x2.shape = ',x2.shape)
print('x2 = ',x2)
x2.shape = torch.Size([4, 1])
x2 = tensor([[1],[2],[3],[4]])
torch.squeeze 举例说明
x = torch.zeros(2, 1, 2, 1, 2)
print('x.shape = ',x.shape)
x.shape = torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x) # 取出所有维度1
print('y.shape = ',y.shape)
y.shape = torch.Size([2, 2, 2])
y = torch.squeeze(x, 1)#去除指定位置1
print('y.shape = ',y.shape)
y.shape = torch.Size([2, 2, 1, 2])
这篇关于PyTorch中torch.squeeze() 和torch.unsqueeze()用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!