本文主要是介绍torch.mean,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
mean()函数的参数:dim=0,按行求平均值,返回的形状是(1,列数);dim=1,按列求平均值,返回的形状是(行数,1),默认不设置dim的时候,返回的是所有元素的平均值。
x=torch.arange(12).view(4,3)
'''
注意:在这里使用的时候转一下类型,否则会报RuntimeError: Can only calculate the mean of floating types. Got Long instead.的错误。
查看了一下x元素类型是torch.int64,根据提示添加一句x=x.float()转为tensor.float32就行
'''
x=x.float()
x_mean=torch.mean(x)
x_mean0=torch.mean(x,dim=0,keepdim=True)
x_mean1=torch.mean(x,dim=1,keepdim=True)
print('x:')
print(x)
print('x_mean0:')
print(x_mean0)
print('x_mean1:')
print(x_mean1)
print('x_mean:')
print(x_mean)
查看了一下x元素类型是torch.int64,根据提示添加一句x=x.float()转为tensor.float32就行
输出结果:
x:
tensor([[ 0., 1., 2.],[ 3., 4., 5.],[ 6., 7., 8.],[ 9., 10., 11.]])
x_mean0:
tensor([[4.5000, 5.5000, 6.5000]])
x_mean1:
tensor([[ 1.],[ 4.],[ 7.],[10.]])
x_mean:
tensor(5.5000)
torch.mean().mean()
x=torch.arange(24).view(4,3,2)
x=x.float()
x_mean=torch.mean(x)
print(x)
print(x.mean())
print(x.mean(dim=0,keepdim=True).mean(dim=1,keepdim=True).mean(dim=2,keepdim=True))
print(x.mean(dim=1,keepdim=True).mean(dim=2,keepdim=True))
输出:
tensor([[[ 0., 1.],[ 2., 3.],[ 4., 5.]],[[ 6., 7.],[ 8., 9.],[10., 11.]],[[12., 13.],[14., 15.],[16., 17.]],[[18., 19.],[20., 21.],[22., 23.]]])
tensor(11.5000)
tensor([[[11.5000]]])
tensor([[[ 2.5000]],[[ 8.5000]],[[14.5000]],[[20.5000]]])
torch.mean()和torch.mean(dim=0).mean(dim=1)的区别
以二维为例:torch.mean()返回的是一个标量,而torch.mean(dim=0).mean(dim=1)返回的是一个1行1列的张量,虽然数值相同
x=torch.arange(12).view(4,3)
x=x.float()
x_mean=torch.mean(x)
print(x_mean)
y= x.mean(dim=0, keepdim=True).mean(dim=1, keepdim=True)
print(y)
输出:
tensor(5.5000)
tensor([[5.5000]])
这篇关于torch.mean的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!