本文主要是介绍【PyTorch】torch.mean(), dim=0, dim=1 详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
创建一个tensor,这个tensor是一个元素类型为浮点型的2维数组
import torchs = torch.arange(6,dtype=float).reshape((2,3))
print(s)
print(s.shape)# 查看tensor的形状tensor([[0., 1., 2.],[3., 4., 5.]], dtype=torch.float64)
torch.Size([2, 3])
dim属性的全称是dimension,表示维度。dim=0为第0个维度,代表行。
对于torch.mean(s,dim=0),表示跨行求平均。
得到的结果是一个向量,分别对应于 1.5=(0.0+3.0)/2, 2.5=(1.0+4.0)/2, 3.5=(2.0+5.0)/2
s1 = torch.mean(s, dim=0)
print(s1)tensor([1.5000, 2.5000, 3.5000], dtype=torch.float64)
同理,对于dim=1为第一个维度,代表列。
对于torch.sum(s,dim=1),表示跨列求平均。
得到的结果同样是一个向量,分别对应于 1.0=(0.0+1.0+2.0)/3, 4.0=(3.0+4.0+5.0)/3
s2 = torch.mean(s, dim=1)
print(s2)tensor([1., 4.], dtype=torch
这篇关于【PyTorch】torch.mean(), dim=0, dim=1 详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!