本文主要是介绍【pytorch 中 torch.max 和 torch.argmax 的区别】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
torch.max 和 torch.argmax 的区别
1.torch.max
torch.max(input, dim, max=None, max_indices=None, keepdim=False) -->> (Tensor, LongTensor)
作用:找出给定tensor的指定维度dim上的上的最大值,并返回最大值在该维度上的值和位置索引。
应用举例
例1——返回相应维度上的最大值,并返回最大值的位置索引
a = torch.randn(4, 4)
a
>tensor([[-1.2360, -0.2942, -0.1222, 0.8475],[ 1.1949, -1.1127, -2.2379, -0.6702],[ 1.5717, -0.9207, 0.1297, -1.8768],[-0.6172, 1.0036, -0.6060, -0.2432]])
torch.max(a, 1)
>torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]),
indices=tensor([3, 0, 0, 1]))
例2——如果max的参数只有一个tensor,则返回该tensor里所有值中的最大值。
a = torch.randn(4, 4)
a
>tensor([[ 0.4997, 0.8054, 0.1761, 0.3055],[-1.2234, 0.3823, 0.2266, -2.9062],[ 0.4390, -1.0142, -0.5314, -1.7095],[-0.2296, -0.4230, -0.7446, -0.0828]])
torch.max(a)
>tensor(0.8054)
例3——如果max的参数是两个相同shape的tensor,则返回两tensor元素对应位置的最大值的新tensor
a = torch.randint(2, 10,(6,4))
a
>tensor([[8, 7, 3, 5],[2, 8, 3, 4],[3, 2, 5, 5],[4, 7, 5, 2],[2, 9, 3, 8],[4, 4, 2, 2]])
b = torch.randint(2, 10,(6,4))
b
>tensor([[9, 8, 9, 2],[4, 3, 3, 4],[6, 9, 2, 7],[4, 3, 2, 7],[4, 4, 9, 2],[8, 2, 6, 2]])
torch.max(a, b)
>tensor([[9, 8, 9, 5],[4, 8, 3, 4],[6, 9, 5, 7],[4, 7, 5, 7],[4, 9, 9, 8],[8, 4, 6, 2]])
2. torch.argmax
函数定义
torch.argmax(input, dim, keepdim=False) → LongTensor
作用:返回输入张量中指定维度的最大值的索引。
应用举例:
例1——指定维度:返回相应维度最大值的索引
a = torch.randn(4, 4)
a
>tensor([[ 1.3398, 0.2663, -0.2686, 0.2450],[-0.7401, -0.8805, -0.3402, -1.1936],[ 0.4907, -1.3948, -1.0691, -0.3132],[-1.6092, 0.5419, -0.2993, 0.3195]])
torch.argmax(a, dim=1)
>tensor([ 0, 2, 0, 1])
例2——不指定维度,返回整体上最大值的序号
a = torch.randint(9,(3, 3))
a
>tensor([[5, 2, 2],[7, 2, 0],[8, 0, 6]])
torch.argmax(a)
>tensor(6)
3.torch.min
用法同max
4.torch.argmin
用法同argmax
这篇关于【pytorch 中 torch.max 和 torch.argmax 的区别】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!