torch.gather用法详解

2024-04-20 20:52
文章标签 详解 用法 torch gather

本文主要是介绍torch.gather用法详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

torch.gather是PyTorch中的一个函数,用于从源张量中按照指定的索引张量来收集数据。

基本语法如下,

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
  • input:输入源张量
  • dim:要收集数据的维度
  • index:索引
  • sparse_grad:如果为True,则gather()在反向传播时会返回稀疏梯度
  • out:输出张量,形状与index相同

用法讲解

假设有以下输入张量x,

x = torch.tensor([[[ 1,  2],[ 3,  4]],[[ 5,  6],[ 7,  8]],[[ 9, 10],[11, 12]]
])

假设有以下索引index,

index = torch.tensor([[[0, 1],[1, 0]],[[1, 0],[0, 1]],[[0, 1],[1, 0]]
])

index的索引及里面的元素的对应关系如下,

index[0, 0, 0] = 0
index[0, 0, 1] = 1
index[0, 1, 0] = 1
index[0, 1, 1] = 0
index[1, 0, 0] = 1
index[1, 0, 1] = 0
index[1, 1, 0] = 0
index[1, 1, 1] = 1
index[2, 0, 0] = 0
index[2, 0, 1] = 1
index[2, 1, 0] = 1
index[2, 1, 1] = 0

接下来,有3种情况出现,分别是dim=0、dim=1、dim=2 

dim=0

拿index里的元素值去替换对应索引中第1个维度的数值,得到新索引,

(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [1, 0, 1]
[0, 1, 0], 1 -> [1, 1, 0]
[0, 1, 1], 0 -> [0, 1, 1]
[1, 0, 0], 1 -> [1, 0, 0]
[1, 0, 1], 0 -> [0, 0, 1]
[1, 1, 0], 0 -> [0, 1, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [0, 0, 0]
[2, 0, 1], 1 -> [1, 0, 1]
[2, 1, 0], 1 -> [1, 1, 0]
[2, 1, 1], 0 -> [0, 1, 1]

有了新索引后,便可根据新索引从输入张量中获取输出张量,

result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[1, 0, 1]],[x[1, 1, 0], x[0, 1, 1]],[[x[1, 0, 0], x[0, 0, 1],[x[0, 1, 0], x[1, 1, 1]],[[x[0, 0, 0], x[1, 0, 1], [x[1, 1, 0], x[0, 1, 1]]]]=[[[1, 6],[7, 4]],[[5, 2],[3, 8]],[[1, 6],[7, 4]]]
”“”

打印输出张量, 

print(result)
"""
实际值:
tensor([[[1, 6],[7, 4]],[[5, 2],[3, 8]],[[1, 6],[7, 4]]])
"""

dim=1

拿index里的元素值去替换对应索引中第2个维度的数值,得到新索引,

(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [0, 1, 1]
[0, 1, 0], 1 -> [0, 1, 0]
[0, 1, 1], 0 -> [0, 0, 1]
[1, 0, 0], 1 -> [1, 1, 0]
[1, 0, 1], 0 -> [1, 0, 1]
[1, 1, 0], 0 -> [1, 0, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [2, 0, 0]
[2, 0, 1], 1 -> [2, 1, 1]
[2, 1, 0], 1 -> [2, 1, 0]
[2, 1, 1], 0 -> [2, 0, 1]

有了新索引后,便可根据新索引从输入张量中获取输出张量,

result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[0, 1, 1]],[x[0, 1, 0], x[0, 0, 1]],[[x[1, 1, 0], x[1, 0, 1],[x[1, 0, 0], x[1, 1, 1]],[[x[2, 0, 0], x[2, 1, 1], [x[2, 1, 0], x[2, 0, 1]]]]=[[[1, 4],[3, 2]],[[7, 6],[5, 8]],[[9, 12],[11, 10]]]
”“”

打印输出张量, 

print(result)
"""
实际值:
tensor([[[ 1,  4],[ 3,  2]],[[ 7,  6],[ 5,  8]],[[ 9, 12],[11, 10]]])
"""

dim=3

拿index里的元素值去替换对应索引中第3个维度的数值,得到新索引,

(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [0, 0, 1]
[0, 1, 0], 1 -> [0, 1, 1]
[0, 1, 1], 0 -> [0, 1, 0]
[1, 0, 0], 1 -> [1, 0, 1]
[1, 0, 1], 0 -> [1, 0, 0]
[1, 1, 0], 0 -> [1, 1, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [2, 0, 0]
[2, 0, 1], 1 -> [2, 0, 1]
[2, 1, 0], 1 -> [2, 1, 1]
[2, 1, 1], 0 -> [2, 1, 0]

有了新索引后,便可根据新索引从输入张量中获取输出张量,

result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[0, 0, 1]],[x[0, 1, 1], x[0, 1, 0]],[[x[1, 0, 1], x[1, 0, 0],[x[1, 1, 0], x[1, 1, 1]],[[x[2, 0, 0], x[2, 0, 1], [x[2, 1, 1], x[2, 1, 0]]]]=[[[1, 2],[4, 3]],[[6, 5],[7, 8]],[[9, 10],[12, 11]]]
”“”

打印输出张量, 

print(result)
"""
实际值:
tensor([[[ 1,  2],[ 4,  3]],[[ 6,  5],[ 7,  8]],[[ 9, 10],[12, 11]]])
"""

这篇关于torch.gather用法详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/921297

相关文章

Conda与Python venv虚拟环境的区别与使用方法详解

《Conda与Pythonvenv虚拟环境的区别与使用方法详解》随着Python社区的成长,虚拟环境的概念和技术也在不断发展,:本文主要介绍Conda与Pythonvenv虚拟环境的区别与使用... 目录前言一、Conda 与 python venv 的核心区别1. Conda 的特点2. Python v

Spring Boot中WebSocket常用使用方法详解

《SpringBoot中WebSocket常用使用方法详解》本文从WebSocket的基础概念出发,详细介绍了SpringBoot集成WebSocket的步骤,并重点讲解了常用的使用方法,包括简单消... 目录一、WebSocket基础概念1.1 什么是WebSocket1.2 WebSocket与HTTP

java中反射Reflection的4个作用详解

《java中反射Reflection的4个作用详解》反射Reflection是Java等编程语言中的一个重要特性,它允许程序在运行时进行自我检查和对内部成员(如字段、方法、类等)的操作,本文将详细介绍... 目录作用1、在运行时判断任意一个对象所属的类作用2、在运行时构造任意一个类的对象作用3、在运行时判断

MySQL 中的 CAST 函数详解及常见用法

《MySQL中的CAST函数详解及常见用法》CAST函数是MySQL中用于数据类型转换的重要函数,它允许你将一个值从一种数据类型转换为另一种数据类型,本文给大家介绍MySQL中的CAST... 目录mysql 中的 CAST 函数详解一、基本语法二、支持的数据类型三、常见用法示例1. 字符串转数字2. 数字

SpringBoot中SM2公钥加密、私钥解密的实现示例详解

《SpringBoot中SM2公钥加密、私钥解密的实现示例详解》本文介绍了如何在SpringBoot项目中实现SM2公钥加密和私钥解密的功能,通过使用Hutool库和BouncyCastle依赖,简化... 目录一、前言1、加密信息(示例)2、加密结果(示例)二、实现代码1、yml文件配置2、创建SM2工具

MyBatis-Plus 中 nested() 与 and() 方法详解(最佳实践场景)

《MyBatis-Plus中nested()与and()方法详解(最佳实践场景)》在MyBatis-Plus的条件构造器中,nested()和and()都是用于构建复杂查询条件的关键方法,但... 目录MyBATis-Plus 中nested()与and()方法详解一、核心区别对比二、方法详解1.and()

Python中你不知道的gzip高级用法分享

《Python中你不知道的gzip高级用法分享》在当今大数据时代,数据存储和传输成本已成为每个开发者必须考虑的问题,Python内置的gzip模块提供了一种简单高效的解决方案,下面小编就来和大家详细讲... 目录前言:为什么数据压缩如此重要1. gzip 模块基础介绍2. 基本压缩与解压缩操作2.1 压缩文

Spring IoC 容器的使用详解(最新整理)

《SpringIoC容器的使用详解(最新整理)》文章介绍了Spring框架中的应用分层思想与IoC容器原理,通过分层解耦业务逻辑、数据访问等模块,IoC容器利用@Component注解管理Bean... 目录1. 应用分层2. IoC 的介绍3. IoC 容器的使用3.1. bean 的存储3.2. 方法注

MySQL 删除数据详解(最新整理)

《MySQL删除数据详解(最新整理)》:本文主要介绍MySQL删除数据的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、前言二、mysql 中的三种删除方式1.DELETE语句✅ 基本语法: 示例:2.TRUNCATE语句✅ 基本语

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客