本文主要是介绍Pytorch基础:torch.expand() 和 torch.repeat(),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
在torch中,如果要改变
某一个tensor的维度
,可以利用view
、expand
、repeat
、transpose
和permute
等方法,这里对这些方法的一些容易混淆的地方做个总结。
expand和repeat函数是pytorch中常用于进行张量数据复制
和维度扩展
的函数,但其工作机制差别很大
,本文对这两个函数进行对比。
1. torch.expand()
作用
: expand()函数可以将张量广播到新的形状。注意
: 只能对维度值为1
的维度进行扩展
,无需扩展
的维度,维度值不变,对应位置可写上原始维度大小
或直接写作-1
;且扩展的Tensor不会分配新的内存
,只是原来的基础上创建新的视图并返回,返回的张量内存
是不连续的
。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous
函数。
expand函数用于将张量中单数维的数据扩展到指定的size。
首先解释下什么叫单数维
(singleton dimensions),张量在某个维度上的size为1
,则称为单数维
。比如zeros(2,3,4)不存在单数维,而zeros(2,1,4)
在第二个维度(即维度1)上为单数维。expand函数仅仅能作用于这些单数维的维度上
。
参数*sizes用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度
,对应位置可写上原始维度
大小或直接写作-1
。
expand函数可能导致原始张量的升维,其作用在张量前面的维度上(在tensor的低维增加更多维度),因此通过expand函数可将张量数据复制多份(可理解为沿着第一个batch的维度上)。
import torcha = torch.tensor([1, 0, 2]) # a -> torch.Size([3])
b1 = a.expand(2, -1) # 第一个维度为升维,第二个维度保持原样
'''
b1为 -> torch.Size([3, 2])
tensor([[1, 0, 2],[1, 0, 2]])
'''a = torch.tensor([[1], [0], [2]]) # a -> torch.Size([3, 1])
b2 = a.expand(-1, 2) # 保持第一个维度,第二个维度只有一个元素,可扩展
'''
b2 -> torch.Size([3, 2])
b2为 tensor([[1, 1],[0, 0],[2, 2]])
'''a = torch.Tensor([[1, 2, 3]]) # a -> torch.Size([1, 3])
b3 = a.expand(4, 3) # 也可写为a.expand(4, -1) 对于某一个维度上的值为1的维度,# 可以在该维度上进行tensor的复制,若大于1则不行
'''
b3 -> torch.Size([4, 3])
tensor([[1.,2.,3.],[1.,2.,3.],[1.,2.,3.],[1.,2.,3.]]
)
'''a = torch.Tensor([[1, 2, 3], [4, 5, 6]]) # a -> torch.Size([2, 3])
b4 = a.expand(4, 6) # 最高几个维度的参数必须和原始shape保持一致,否则报错
'''
RuntimeError: The expanded size of the tensor (6) must match
the existing size (3) at non-singleton dimension 1.
'''b5 = a.expand(1, 2, 3) # 可以在tensor的低维增加更多维度
'''
b5 -> torch.Size([1,2, 3])
tensor([[[1.,2.,3.],[4.,5.,6.]]]
)
'''
b6 = a.expand(2, 2, 3) # 可以在tensor的低维增加更多维度,同时在新增加的低维度上进行tensor的复制
'''
b5 -> torch.Size([2,2, 3])
tensor([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]]
)
'''b7 = a.expand(2, 3, 2) # 不可在更高维增加维度,否则报错
'''
RuntimeError: The expanded size of the tensor (2) must match the
existing size (3) at non-singleton dimension 2.
'''b8 = a.expand(2, -1, -1) # 最高几个维度的参数可以用-1,表示和原始维度一致
'''
b8 -> torch.Size([2,2, 3])
tensor([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]]
)
'''# expand返回的张量与原版张量具有相同内存地址
print(b8.storage()) # 存储区的数据,说明expand后的a,aa,aaa,aaaa是共享storage的,
# 只是tensor的头信息区设置了不同的数据展示格式,从而使得a,aa,aaa,aaaa呈现不同的tensor形式
'''
1.0
2.0
3.0
4.0
5.0
6.0
'''
1.1 expand_as
可视为expand的另一种表达,其size通过函数传递的目标张量的size来定义。
import torch
a = torch.tensor([1, 0, 2])
b = torch.zeros(2, 3)
c = a.expand_as(b) # a照着b的维度大小进行拓展
# c为 tensor([[1, 0, 2],
# [1, 0, 2]])
2 tensor.repeat()
沿着特定维度扩展张量,并返回扩展后的张量
- 作用:和expand()作用类似,均是将tensor广播到新的形状。
- 注意:
不允许使用维度-1,1即为不变
import torchif __name__ == '__main__':x = torch.rand(2, 3)y1 = x.repeat(4, 2)print(y1.shape) # torch.Size([8, 6])
3. 两者内存占用的区别
-
torch.expand 不会占用额外空间
,只是在存在的张量上创建一个新的视图 -
torch.repeat
和 torch.expand 不同,它是拷贝了数据,会占用额外的空间
示例如下:
import torchif __name__ == '__main__':x = torch.rand(1, 3)y1 = x.expand(4, 3)y2 = x.repeat(2, 3)print(x.storage().data_ptr(), y1.storage().data_ptr()) # 52364352 52364352print(x.storage().data_ptr(), y2.storage().data_ptr()) # 52364352 8852096
这篇关于Pytorch基础:torch.expand() 和 torch.repeat()的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!