本文主要是介绍pytorch 常用函数 max ,eq,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
max找出tensor 的行或者列最大的值:
找出每行的最大值:
import torchoutputs=torch.FloatTensor([[1],[2],[3]])print(torch.max(outputs.data,1))
输出:
(tensor([ 1., 2., 3.]), tensor([ 0, 0, 0]))
找出每列的最大值:
import torchoutputs=torch.FloatTensor([[1],[2],[3]])print(torch.max(outputs.data,0))
输出结果:
(tensor([ 3.]), tensor([ 2]))
Tensor比较eq相等:
import torchoutputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data))
输出结果:
tensor([[ 0],
[ 1],
[ 1]], dtype=torch.uint8)
使用sum() 统计相等的个数:
import torchoutputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data).cpu().sum())
输出结果:
tensor(2)
这篇关于pytorch 常用函数 max ,eq的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!