本文主要是介绍pytorch中repeat()函数理解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
pytorch中repeat()函数理解
最近在学习过程中遇到了repeat()函数的使用,这里记录一下自己对这个函数的理解。
情况1:repeat参数个数与tensor维数一致时
a = torch.tensor([[1, 2, 3],[1, 2, 3]])
b = a.repeat(2, 2)
print(b.shape)
运行结果:
即repeat的参数是对应维度的复制个数,上段代码为0维复制两次,1维复制两次,则得到以上运行结果。其余扩展情况依此类推
情况2:repeat参数个数与tensor维数不一致时
# a形状(2,3)
a = torch.tensor([[1, 2, 3],[1, 2, 3]])
# repeat参数比维度多,在扩展前先讲a的形状扩展为(1,2,3)然后复制
b = a.repeat(1, 2, 1)
print(b.shape) # 得到结果torch.Size([1, 4, 3])
首先在第0维扩展一个维度,维数为1,然后按照参数指定的次数进行复制
# a形状(2,3)
a = torch.tensor([[1, 2, 3],[1, 2, 3]])
# repeat参数比维度多,在扩展前先讲a的形状扩展为(1,2,3)然后复制
b = a.repeat(1, 1, 2)
print(b.shape) # 得到结果torch.Size([1, 2, 6])
# a形状(2,3)
a = torch.tensor([[1, 2, 3],[1, 2, 3]])
# repeat参数比维度多,在扩展前先讲a的形状扩展为(1,2,3)然后复制
b = a.repeat(2, 1, 1)
print(b.shape) # 得到结果torch.Size([2, 2, 3])
以上演示可以看到,在参数个数大于原tensor维度个数时,总是先在第0维扩展一个维数为1的维度,然后按照参数指定的复制次数进行复制。计算输出的形状时,可以按照 对应参数*对应维度维数 得到结果
这篇关于pytorch中repeat()函数理解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!