本文主要是介绍torch.gather()取每行中不同列的元素,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
pytorch取每行中不同列的元素
import torch
scores = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]
])
label=torch.LongTensor([ [0],[1],[2] ])
ans = scores.gather(1, label)
print(ans)
常用场合:信息检索或者推荐系统模型中计算指标要获得item 或者 document经过模型排序后的结果
对item算score,然后要对score算排名,最后根据排名取出前十个item
#item为id,在model里通过item id取出 embedding
score = model(item) # item, score shape [batch, 100]
ranks = torch,argsort(score, dim=-1, descending=True) # 降序排列
ranks = ranks[:, :10] #对每个样本取前十
ids = item.gather(1, ranks) # [batch, 10] ids 为item经过model计算分数后排序的前十物品
这篇关于torch.gather()取每行中不同列的元素的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!