torch.gather用法详解

2024-04-20 20:52
文章标签 详解 用法 torch gather

本文主要是介绍torch.gather用法详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

torch.gather是PyTorch中的一个函数,用于从源张量中按照指定的索引张量来收集数据。

基本语法如下,

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
  • input:输入源张量
  • dim:要收集数据的维度
  • index:索引
  • sparse_grad:如果为True,则gather()在反向传播时会返回稀疏梯度
  • out:输出张量,形状与index相同

用法讲解

假设有以下输入张量x,

x = torch.tensor([[[ 1,  2],[ 3,  4]],[[ 5,  6],[ 7,  8]],[[ 9, 10],[11, 12]]
])

假设有以下索引index,

index = torch.tensor([[[0, 1],[1, 0]],[[1, 0],[0, 1]],[[0, 1],[1, 0]]
])

index的索引及里面的元素的对应关系如下,

index[0, 0, 0] = 0
index[0, 0, 1] = 1
index[0, 1, 0] = 1
index[0, 1, 1] = 0
index[1, 0, 0] = 1
index[1, 0, 1] = 0
index[1, 1, 0] = 0
index[1, 1, 1] = 1
index[2, 0, 0] = 0
index[2, 0, 1] = 1
index[2, 1, 0] = 1
index[2, 1, 1] = 0

接下来,有3种情况出现,分别是dim=0、dim=1、dim=2 

dim=0

拿index里的元素值去替换对应索引中第1个维度的数值,得到新索引,

(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [1, 0, 1]
[0, 1, 0], 1 -> [1, 1, 0]
[0, 1, 1], 0 -> [0, 1, 1]
[1, 0, 0], 1 -> [1, 0, 0]
[1, 0, 1], 0 -> [0, 0, 1]
[1, 1, 0], 0 -> [0, 1, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [0, 0, 0]
[2, 0, 1], 1 -> [1, 0, 1]
[2, 1, 0], 1 -> [1, 1, 0]
[2, 1, 1], 0 -> [0, 1, 1]

有了新索引后,便可根据新索引从输入张量中获取输出张量,

result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[1, 0, 1]],[x[1, 1, 0], x[0, 1, 1]],[[x[1, 0, 0], x[0, 0, 1],[x[0, 1, 0], x[1, 1, 1]],[[x[0, 0, 0], x[1, 0, 1], [x[1, 1, 0], x[0, 1, 1]]]]=[[[1, 6],[7, 4]],[[5, 2],[3, 8]],[[1, 6],[7, 4]]]
”“”

打印输出张量, 

print(result)
"""
实际值:
tensor([[[1, 6],[7, 4]],[[5, 2],[3, 8]],[[1, 6],[7, 4]]])
"""

dim=1

拿index里的元素值去替换对应索引中第2个维度的数值,得到新索引,

(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [0, 1, 1]
[0, 1, 0], 1 -> [0, 1, 0]
[0, 1, 1], 0 -> [0, 0, 1]
[1, 0, 0], 1 -> [1, 1, 0]
[1, 0, 1], 0 -> [1, 0, 1]
[1, 1, 0], 0 -> [1, 0, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [2, 0, 0]
[2, 0, 1], 1 -> [2, 1, 1]
[2, 1, 0], 1 -> [2, 1, 0]
[2, 1, 1], 0 -> [2, 0, 1]

有了新索引后,便可根据新索引从输入张量中获取输出张量,

result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[0, 1, 1]],[x[0, 1, 0], x[0, 0, 1]],[[x[1, 1, 0], x[1, 0, 1],[x[1, 0, 0], x[1, 1, 1]],[[x[2, 0, 0], x[2, 1, 1], [x[2, 1, 0], x[2, 0, 1]]]]=[[[1, 4],[3, 2]],[[7, 6],[5, 8]],[[9, 12],[11, 10]]]
”“”

打印输出张量, 

print(result)
"""
实际值:
tensor([[[ 1,  4],[ 3,  2]],[[ 7,  6],[ 5,  8]],[[ 9, 12],[11, 10]]])
"""

dim=3

拿index里的元素值去替换对应索引中第3个维度的数值,得到新索引,

(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [0, 0, 1]
[0, 1, 0], 1 -> [0, 1, 1]
[0, 1, 1], 0 -> [0, 1, 0]
[1, 0, 0], 1 -> [1, 0, 1]
[1, 0, 1], 0 -> [1, 0, 0]
[1, 1, 0], 0 -> [1, 1, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [2, 0, 0]
[2, 0, 1], 1 -> [2, 0, 1]
[2, 1, 0], 1 -> [2, 1, 1]
[2, 1, 1], 0 -> [2, 1, 0]

有了新索引后,便可根据新索引从输入张量中获取输出张量,

result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[0, 0, 1]],[x[0, 1, 1], x[0, 1, 0]],[[x[1, 0, 1], x[1, 0, 0],[x[1, 1, 0], x[1, 1, 1]],[[x[2, 0, 0], x[2, 0, 1], [x[2, 1, 1], x[2, 1, 0]]]]=[[[1, 2],[4, 3]],[[6, 5],[7, 8]],[[9, 10],[12, 11]]]
”“”

打印输出张量, 

print(result)
"""
实际值:
tensor([[[ 1,  2],[ 4,  3]],[[ 6,  5],[ 7,  8]],[[ 9, 10],[12, 11]]])
"""

这篇关于torch.gather用法详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/921297

相关文章

java图像识别工具类(ImageRecognitionUtils)使用实例详解

《java图像识别工具类(ImageRecognitionUtils)使用实例详解》:本文主要介绍如何在Java中使用OpenCV进行图像识别,包括图像加载、预处理、分类、人脸检测和特征提取等步骤... 目录前言1. 图像识别的背景与作用2. 设计目标3. 项目依赖4. 设计与实现 ImageRecogni

Java访问修饰符public、private、protected及默认访问权限详解

《Java访问修饰符public、private、protected及默认访问权限详解》:本文主要介绍Java访问修饰符public、private、protected及默认访问权限的相关资料,每... 目录前言1. public 访问修饰符特点:示例:适用场景:2. private 访问修饰符特点:示例:

python管理工具之conda安装部署及使用详解

《python管理工具之conda安装部署及使用详解》这篇文章详细介绍了如何安装和使用conda来管理Python环境,它涵盖了从安装部署、镜像源配置到具体的conda使用方法,包括创建、激活、安装包... 目录pytpshheraerUhon管理工具:conda部署+使用一、安装部署1、 下载2、 安装3

详解Java如何向http/https接口发出请求

《详解Java如何向http/https接口发出请求》这篇文章主要为大家详细介绍了Java如何实现向http/https接口发出请求,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 用Java发送web请求所用到的包都在java.net下,在具体使用时可以用如下代码,你可以把它封装成一

JAVA系统中Spring Boot应用程序的配置文件application.yml使用详解

《JAVA系统中SpringBoot应用程序的配置文件application.yml使用详解》:本文主要介绍JAVA系统中SpringBoot应用程序的配置文件application.yml的... 目录文件路径文件内容解释1. Server 配置2. Spring 配置3. Logging 配置4. Ma

mac中资源库在哪? macOS资源库文件夹详解

《mac中资源库在哪?macOS资源库文件夹详解》经常使用Mac电脑的用户会发现,找不到Mac电脑的资源库,我们怎么打开资源库并使用呢?下面我们就来看看macOS资源库文件夹详解... 在 MACOS 系统中,「资源库」文件夹是用来存放操作系统和 App 设置的核心位置。虽然平时我们很少直接跟它打交道,但了

关于Maven中pom.xml文件配置详解

《关于Maven中pom.xml文件配置详解》pom.xml是Maven项目的核心配置文件,它描述了项目的结构、依赖关系、构建配置等信息,通过合理配置pom.xml,可以提高项目的可维护性和构建效率... 目录1. POM文件的基本结构1.1 项目基本信息2. 项目属性2.1 引用属性3. 项目依赖4. 构

Rust 数据类型详解

《Rust数据类型详解》本文介绍了Rust编程语言中的标量类型和复合类型,标量类型包括整数、浮点数、布尔和字符,而复合类型则包括元组和数组,标量类型用于表示单个值,具有不同的表示和范围,本文介绍的非... 目录一、标量类型(Scalar Types)1. 整数类型(Integer Types)1.1 整数字

Java操作ElasticSearch的实例详解

《Java操作ElasticSearch的实例详解》Elasticsearch是一个分布式的搜索和分析引擎,广泛用于全文搜索、日志分析等场景,本文将介绍如何在Java应用中使用Elastics... 目录简介环境准备1. 安装 Elasticsearch2. 添加依赖连接 Elasticsearch1. 创

Redis缓存问题与缓存更新机制详解

《Redis缓存问题与缓存更新机制详解》本文主要介绍了缓存问题及其解决方案,包括缓存穿透、缓存击穿、缓存雪崩等问题的成因以及相应的预防和解决方法,同时,还详细探讨了缓存更新机制,包括不同情况下的缓存更... 目录一、缓存问题1.1 缓存穿透1.1.1 问题来源1.1.2 解决方案1.2 缓存击穿1.2.1