本文主要是介绍pytorch中的维度变换操作性质大总结:view, reshape, transpose, permute,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
在深度学习中,张量的维度变换是很重要的操作。在pytorch中,有四个用于维度变换的函数,view
, reshape
, transpose
, permute
。其中view
, reshape
都用于改变张量的形状,transpose
, permute
都用于重新排列张量的维度,但它们的功能和使用场景有所不同,下面将进行详细介绍,并给出测试验证代码,经过全面的了解,我们才能知道如何正确的使用这四个函数。
这里写目录标题
- 1. torch.Tensor.view
- 2. torch.reshape
- 3. torch.transpose
- 4. torch.permute
- 5. torch.transpose与torch.permute的性质与原理
1. torch.Tensor.view
文档:Doc
- view 方法返回一个新的张量,具有与原始张量相同的数据,但改变了形状。所以view返回的是原始数据的一个新尺寸的视图,这也就是为什么叫做view。
输出:import torch # 创建一个2x6的张量 x = torch.tensor([[1, 2, 3, 4, 5, 6],[7, 8, 9, 10, 11, 12]]) # 将其调整为3x4的形状 y = x.view(3, 4) print("x shape: ", x.shape) print("y shape: ", y.shape) # 判断新旧张量是否数据是相同的 print(x.data_ptr() == y.data_ptr())
x shape: torch.Size([2, 6]) y shape: torch.Size([3, 4]) True
- view 要求原始张量是连续的(即在内存中是按顺序存储的),否则会抛出错误。
报错输出:import torch # 创建一个2x6的张量 x = torch.tensor([[1, 2, 3, 4, 5, 6],[7, 8, 9, 10, 11, 12]]) # 将向量转置,此时x不再是连续的 x = x.T # 在不连续的张量上进行view将会报错 y = x.view(3, 4)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
- 如果张量不是连续的,可以使用 contiguous 方法先将其转换为连续的。
2. torch.reshape
文档:Doc
- reshape不要求原始张量是连续的
- 如果原始张量是连续的,那么实现的功能和view一样
- 如果原始张量不是连续的,那么reshape就是tensor.contigous().view(),也就是会重新开辟一块内存空间,拷贝原始张量,使其连续;
- 在连续张量上,view 和 reshape 性能相同。在非连续张量上,reshape 可能会稍慢一些,因为它可能需要创建新的连续张量。
输出:import torch # 创建一个2x6的张量 x = torch.tensor([[1, 2, 3, 4, 5, 6],[7, 8, 9, 10, 11, 12]]) # 将向量转置,此时x不再是连续的 x = x.T # 在不连续的张量上可以进行reshape y = x.reshape(3, 4) print("x shape: ", x.shape) print("y shape: ", y.shape) # 但reshape返回的是新的内存中的张量 print(x.data_ptr() == y.data_ptr())
x shape: torch.Size([6, 2]) y shape: torch.Size([3, 4]) False
3. torch.transpose
Doc
- 功能:仅用于交换两个维度。它接受两个维度参数,分别表示要交换的维度。
- 不改变数据:不会改变数据本身,只是改变数据的视图(即不复制数据)。
- 生成的新张量也通常不是连续的。它只是交换两个维度的顺序,不改变数据在内存中的实际存储顺序。
- 对原始张量是不是连续的没有要求
输出:import torch # 创建一个3x4的张量 x = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]]) print(x.is_contiguous()) # 交换第一个和第二个维度 y = torch.transpose(x, 0, 1) print(y.is_contiguous()) print("x shape: ", x.shape) print("y shape: ", y.shape) print(x.data_ptr() == y.data_ptr())
True False x shape: torch.Size([3, 4]) y shape: torch.Size([4, 3]) True
4. torch.permute
Doc
- 可以重新排列任意数量的维度,适用于复杂的维度变换。接受一个shape元组作为参数
- 不改变数据:不会改变数据本身,只是改变数据的视图(即不复制数据)
- 生成的新张量通常不是连续的。因为它仅改变维度顺序,不改变数据在内存中的实际顺序。
- 对原始张量是不是连续的没有要求
输出:import torch# 创建一个3x4x5的张量x = torch.randn(3, 4, 5)# 将其第一个和第二个维度交换y = torch.permute(x, (1, 0, 2))print(y.is_contiguous())print(x.data_ptr() == y.data_ptr())print(y.size()) # 输出:torch.Size([4, 3, 5])
False True torch.Size([4, 3, 5])
5. torch.transpose与torch.permute的性质与原理
这两者的功能和各方面的性质基本是相同的,只是一个只能交换两个维度,一个能进行更复杂的维度排列。他们的原理是:transpose 和 permute 通过改变张量的 strides(步幅)来重新排列维度。strides 定义了在内存中沿着每个维度移动的步长。它们不改变张量的数据,只是改变了访问数据的方式。因此,这些操作可以应用于任何张量,无论它们是否连续。
这篇关于pytorch中的维度变换操作性质大总结:view, reshape, transpose, permute的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!