本文主要是介绍pytorch中 tensor.repeat()函数用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文章目录
- torch.tensor.repeat(*size)
- 函数作用
- 举例
官方解释链接
torch.tensor.repeat(*size)
函数作用
用于进行张量数据复制和维度扩展的函数。参数是沿着维度重复的次数。
注意⚠️:
- repeat()函数跟expand()函数不同,repeat()函数重复的时候是根据tensor内的数据。
- 同时,
对于函数中的参数个数一定不能小于tensor的维度个数
。【1】 size中不允许使用负数
,保持不变时使用1
。【2】
举例
参数数量不正确:
res_1.shape # torch.Size([4, 1])
res_2 = res_1.repeat(3)
res_2.shape# 输出
RuntimeError Traceback (most recent call last)
Cell In[7], line 1
----> 1 res_2 = res_1.repeat(3)2 res_2.shapeRuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
参数错误 使用负数:
res_2 = res_1.repeat(-2, 3)
res_2.shape#输出
RuntimeError Traceback (most recent call last)
Cell In[10], line 1
----> 1 res_2 = res_1.repeat(-2,3)2 res_2.shapeRuntimeError: Trying to create tensor with negative dimension -8: [-8, 3]
正确使用例子:参数数量大于等于维度个数,且为正整数
res_1.shape # torch.Size([4, 1])
res_2 = res_1.repeat(2, 3)
res_2.shape#输出
torch.Size([8, 3])
references:
[1] PyTorch中tensor.repeat()的使用
[2] Pytorch 中 expand和repeat
这篇关于pytorch中 tensor.repeat()函数用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!