pytorch中的维度变换操作性质大总结:view, reshape, transpose, permute

本文主要是介绍pytorch中的维度变换操作性质大总结:view, reshape, transpose, permute,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在深度学习中,张量的维度变换是很重要的操作。在pytorch中,有四个用于维度变换的函数,view, reshape, transpose, permute。其中view, reshape都用于改变张量的形状,transpose, permute都用于重新排列张量的维度,但它们的功能和使用场景有所不同,下面将进行详细介绍,并给出测试验证代码,经过全面的了解,我们才能知道如何正确的使用这四个函数。

这里写目录标题

    • 1. torch.Tensor.view
    • 2. torch.reshape
    • 3. torch.transpose
    • 4. torch.permute
    • 5. torch.transpose与torch.permute的性质与原理

1. torch.Tensor.view

文档:Doc

  • view 方法返回一个新的张量,具有与原始张量相同的数据,但改变了形状。所以view返回的是原始数据的一个新尺寸的视图,这也就是为什么叫做view。
    import torch
    # 创建一个2x6的张量
    x = torch.tensor([[1, 2, 3, 4, 5, 6],[7, 8, 9, 10, 11, 12]])
    # 将其调整为3x4的形状
    y = x.view(3, 4)
    print("x shape: ", x.shape)
    print("y shape: ", y.shape)
    # 判断新旧张量是否数据是相同的
    print(x.data_ptr() == y.data_ptr())
    
    输出:
    x shape:  torch.Size([2, 6])
    y shape:  torch.Size([3, 4])
    True
    
  • view 要求原始张量是连续的(即在内存中是按顺序存储的),否则会抛出错误。
    import torch
    # 创建一个2x6的张量
    x = torch.tensor([[1, 2, 3, 4, 5, 6],[7, 8, 9, 10, 11, 12]])
    # 将向量转置,此时x不再是连续的
    x = x.T
    # 在不连续的张量上进行view将会报错
    y = x.view(3, 4)
    
    报错输出:
    RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
    
  • 如果张量不是连续的,可以使用 contiguous 方法先将其转换为连续的。

2. torch.reshape

文档:Doc

  • reshape不要求原始张量是连续的
  • 如果原始张量是连续的,那么实现的功能和view一样
  • 如果原始张量不是连续的,那么reshape就是tensor.contigous().view(),也就是会重新开辟一块内存空间,拷贝原始张量,使其连续;
  • 在连续张量上,view 和 reshape 性能相同。在非连续张量上,reshape 可能会稍慢一些,因为它可能需要创建新的连续张量。
    import torch
    # 创建一个2x6的张量
    x = torch.tensor([[1, 2, 3, 4, 5, 6],[7, 8, 9, 10, 11, 12]])
    # 将向量转置,此时x不再是连续的
    x = x.T
    # 在不连续的张量上可以进行reshape
    y = x.reshape(3, 4)
    print("x shape: ", x.shape)
    print("y shape: ", y.shape)
    # 但reshape返回的是新的内存中的张量
    print(x.data_ptr() == y.data_ptr())
    
    输出:
    x shape:  torch.Size([6, 2])
    y shape:  torch.Size([3, 4])
    False
    

3. torch.transpose

Doc

  • 功能:仅用于交换两个维度。它接受两个维度参数,分别表示要交换的维度。
  • 不改变数据:不会改变数据本身,只是改变数据的视图(即不复制数据)。
  • 生成的新张量也通常不是连续的。它只是交换两个维度的顺序,不改变数据在内存中的实际存储顺序。
  • 对原始张量是不是连续的没有要求
    import torch
    # 创建一个3x4的张量
    x = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]])
    print(x.is_contiguous())
    # 交换第一个和第二个维度
    y = torch.transpose(x, 0, 1)
    print(y.is_contiguous())
    print("x shape: ", x.shape)
    print("y shape: ", y.shape)
    print(x.data_ptr() == y.data_ptr())
    
    输出:
    True
    False
    x shape:  torch.Size([3, 4])
    y shape:  torch.Size([4, 3])
    True
    

4. torch.permute

Doc

  • 可以重新排列任意数量的维度,适用于复杂的维度变换。接受一个shape元组作为参数
  • 不改变数据:不会改变数据本身,只是改变数据的视图(即不复制数据)
  • 生成的新张量通常不是连续的。因为它仅改变维度顺序,不改变数据在内存中的实际顺序。
  • 对原始张量是不是连续的没有要求
    	import torch# 创建一个3x4x5的张量x = torch.randn(3, 4, 5)# 将其第一个和第二个维度交换y = torch.permute(x, (1, 0, 2))print(y.is_contiguous())print(x.data_ptr() == y.data_ptr())print(y.size())  # 输出:torch.Size([4, 3, 5])
    
    输出:
    False
    True
    torch.Size([4, 3, 5])
    

5. torch.transpose与torch.permute的性质与原理

这两者的功能和各方面的性质基本是相同的,只是一个只能交换两个维度,一个能进行更复杂的维度排列。他们的原理是:transpose 和 permute 通过改变张量的 strides(步幅)来重新排列维度。strides 定义了在内存中沿着每个维度移动的步长。它们不改变张量的数据,只是改变了访问数据的方式。因此,这些操作可以应用于任何张量,无论它们是否连续。

这篇关于pytorch中的维度变换操作性质大总结:view, reshape, transpose, permute的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1037317

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

git使用的说明总结

Git使用说明 下载安装(下载地址) macOS: Git - Downloading macOS Windows: Git - Downloading Windows Linux/Unix: Git (git-scm.com) 创建新仓库 本地创建新仓库:创建新文件夹,进入文件夹目录,执行指令 git init ,用以创建新的git 克隆仓库 执行指令用以创建一个本地仓库的

二分最大匹配总结

HDU 2444  黑白染色 ,二分图判定 const int maxn = 208 ;vector<int> g[maxn] ;int n ;bool vis[maxn] ;int match[maxn] ;;int color[maxn] ;int setcolor(int u , int c){color[u] = c ;for(vector<int>::iter

整数Hash散列总结

方法:    step1  :线性探测  step2 散列   当 h(k)位置已经存储有元素的时候,依次探查(h(k)+i) mod S, i=1,2,3…,直到找到空的存储单元为止。其中,S为 数组长度。 HDU 1496   a*x1^2+b*x2^2+c*x3^2+d*x4^2=0 。 x在 [-100,100] 解的个数  const int MaxN = 3000

状态dp总结

zoj 3631  N 个数中选若干数和(只能选一次)<=M 的最大值 const int Max_N = 38 ;int a[1<<16] , b[1<<16] , x[Max_N] , e[Max_N] ;void GetNum(int g[] , int n , int s[] , int &m){ int i , j , t ;m = 0 ;for(i = 0 ;

go基础知识归纳总结

无缓冲的 channel 和有缓冲的 channel 的区别? 在 Go 语言中,channel 是用来在 goroutines 之间传递数据的主要机制。它们有两种类型:无缓冲的 channel 和有缓冲的 channel。 无缓冲的 channel 行为:无缓冲的 channel 是一种同步的通信方式,发送和接收必须同时发生。如果一个 goroutine 试图通过无缓冲 channel

9.8javaweb项目总结

1.主界面用户信息显示 登录成功后,将用户信息存储在记录在 localStorage中,然后进入界面之前通过js来渲染主界面 存储用户信息 将用户信息渲染在主界面上,并且头像设置跳转,到个人资料界面 这里数据库中还没有设置相关信息 2.模糊查找 检测输入框是否有变更,有的话调用方法,进行查找 发送检测请求,然后接收的时候设置最多显示四个类似的搜索结果

动手学深度学习【数据操作+数据预处理】

import osos.makedirs(os.path.join('.', 'data'), exist_ok=True)data_file = os.path.join('.', 'data', 'house_tiny.csv')with open(data_file, 'w') as f:f.write('NumRooms,Alley,Price\n') # 列名f.write('NA

线程的四种操作

所属专栏:Java学习        1. 线程的开启 start和run的区别: run:描述了线程要执行的任务,也可以称为线程的入口 start:调用系统函数,真正的在系统内核中创建线程(创建PCB,加入到链表中),此处的start会根据不同的系统,分别调用不同的api,创建好之后的线程,再单独去执行run(所以说,start的本质是调用系统api,系统的api