本文主要是介绍torch.gather——沿特定维度收集数值,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
PyTorch学习笔记:torch.gather——沿特定维度收集数值
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
功能:从输入的数组中,沿指定的dim
维度,利用索引变量index
,将数据索引出来,并且堆叠成一个数组。直观可能不好理解,具体可以见代码案例。
输入:
input
:输入的数组
dim
:指定的维度
index
:索引变量,数据类型需是长整型(int64)
注意:
-
input
和index
具有相同的维数 -
out
和index
具有相同的形状 -
除了
dim
维度,在每个维度上,索引在该维度上的大小要小于等于输入在该维度上的大小,即:
i n d e x . s i z e ( d ) ≤ i n p u t . s i z e ( d ) , d ! = d i m index.size(d)≤input.size(d),\quad d!=dim index.size(d)≤input.size(d),d!=dim
代码案例
一般用法,当在一个维度上进行索引时,以第一维度为例
import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[2,3,1,3]])
b=torch.gather(a,dim=0,index=index)
print(a)
print(b)
输出
以第二维度为例
import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[2],[3],[1],[3]])
b=torch.gather(a,dim=1,index=index)
print(a)
print(b)
输出
当同时在两个维度上进行索引时,以第一维度为例
import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[1,2,3],[2,3,0],[3,0,1]])
b=torch.gather(a,dim=0,index=index)
print(a)
print(b)
输出
tensor([[ 0, 1, 2, 3, 4],[ 5, 6, 7, 8, 9],[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]])
tensor([[ 5, 11, 17],[10, 16, 2],[15, 1, 7]])
以第二维度为例
import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[1,2],[2,3],[3,4]])
b=torch.gather(a,dim=1,index=index)
print(a)
print(b)
输出
tensor([[ 0, 1, 2, 3, 4],[ 5, 6, 7, 8, 9],[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]])
tensor([[ 1, 2],[ 7, 8],[13, 14]])
官方文档
torch.gather:https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=torch%20gather#torch.gather
这篇关于torch.gather——沿特定维度收集数值的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!