Pytorch函数——torch.gather详解

2024-01-15 21:20

本文主要是介绍Pytorch函数——torch.gather详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在学习强化学习时,顺便复习复习pytorch的基本内容,遇到了 torch.gather()函数,参考图解PyTorch中的torch.gather函数 - 知乎 (zhihu.com)进行解释。

pytorch官网对函数给出的解释:

image.png

即input是一个矩阵,根据dim的值,将index的值替换到不同的维度的索引,当dim为0时,index替代i的值,成为第0维度的索引。

输入和输出的矩阵形式相同。

例子:首先我们生成3×3的矩阵,明确行索引的概念,第0行指的是[3,4,5],第0列指的是[[3] [6] [9]]

import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
tensor([[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])

index为行向量且dim=0时

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
tensor([[9, 7, 5]])

当dim=0时,替换第0维度。由于input为二维列表,因此第0维度指的是选择第几行的维度,即行索引所在的维度,替换了i的索引,为input[index[i][j]] [j]

那么我们会输出tensor([[ input[2][j] input[1][j] input[0][j] ]]),那么j如何获得呢?从index of index中拿到,index每一个元素的索引为(0,0) (0,1) (0,2),取j,则为0,1,2,那么输出则为tensor([[ input[2][0] input[1][1] input[0][2] ]]),即

tensor([[9, 7, 5]])

输入行向量index,且dim=1

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[5, 4, 3]])

维度为1,则替换列索引的值,那么输出为tensor([[ input[i][2] input[i][1] input[i][0] ]]),index每一个元素的索引为(0,0) (0,1) (0,2),i均为1,那么tensor([[ input[0][2] input[0][1] input[0][0] ]])

输入为行向量,dim=0

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
tensor([[5],[7],[9]])

维度为0,则替换行索引,且输出与输入的格式相同,为

tensor([input[2][j],input[1][j],input[0][j]])

index每一个元素的索引为(0,0) (1,0) (2,0),j对应的值为0,0,0,则

tensor([input[2][0],input[1][0],input[0][0]])

输入为行向量,dim=1

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[5],[7],[9]])

维度为1,则替换列索引,且输出与输入的格式相同,为

tensor([input[i][2],input[i][1],input[i][0]])

index每一个元素的索引为(0,0) (1,0) (2,0),i对应的值为0,1,2,则

tensor([input[0][2],input[1][1],input[2][0]])

输入为二维矩阵,且dim=1

index = torch.tensor([[0, 2],[1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[3, 5],[7, 8]])

维度为1,则替换列索引,且输出与输入的格式相同,为

tensor([[input[i][0], input[i][2]],[input[i][1], input[i][2]]])

替换为行索引后,可得:

tensor([[input[0][0], input[0][2]],[input[1][1], input[1][2]]])

在强化学习中的应用

在PyTorch官网DQN页面的代码中,i是state,j是a

# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken. These are the actions which would've been taken
# for each batch state according to policy_net
state_action_values = policy_net(state_batch).gather(1, action_batch)

我们使用dim=1action_batch 将获得的动作列表替换为列索引,即可获得每个state下该动作的动作价值。

这篇关于Pytorch函数——torch.gather详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/610246

相关文章

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

hdu1171(母函数或多重背包)

题意:把物品分成两份,使得价值最接近 可以用背包,或者是母函数来解,母函数(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v) 其中指数为价值,每一项的数目为(该物品数+1)个 代码如下: #include<iostream>#include<algorithm>

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof

6.1.数据结构-c/c++堆详解下篇(堆排序,TopK问题)

上篇:6.1.数据结构-c/c++模拟实现堆上篇(向下,上调整算法,建堆,增删数据)-CSDN博客 本章重点 1.使用堆来完成堆排序 2.使用堆解决TopK问题 目录 一.堆排序 1.1 思路 1.2 代码 1.3 简单测试 二.TopK问题 2.1 思路(求最小): 2.2 C语言代码(手写堆) 2.3 C++代码(使用优先级队列 priority_queue)

K8S(Kubernetes)开源的容器编排平台安装步骤详解

K8S(Kubernetes)是一个开源的容器编排平台,用于自动化部署、扩展和管理容器化应用程序。以下是K8S容器编排平台的安装步骤、使用方式及特点的概述: 安装步骤: 安装Docker:K8S需要基于Docker来运行容器化应用程序。首先要在所有节点上安装Docker引擎。 安装Kubernetes Master:在集群中选择一台主机作为Master节点,安装K8S的控制平面组件,如AP

C++操作符重载实例(独立函数)

C++操作符重载实例,我们把坐标值CVector的加法进行重载,计算c3=c1+c2时,也就是计算x3=x1+x2,y3=y1+y2,今天我们以独立函数的方式重载操作符+(加号),以下是C++代码: c1802.cpp源代码: D:\YcjWork\CppTour>vim c1802.cpp #include <iostream>using namespace std;/*** 以独立函数

嵌入式Openharmony系统构建与启动详解

大家好,今天主要给大家分享一下,如何构建Openharmony子系统以及系统的启动过程分解。 第一:OpenHarmony系统构建      首先熟悉一下,构建系统是一种自动化处理工具的集合,通过将源代码文件进行一系列处理,最终生成和用户可以使用的目标文件。这里的目标文件包括静态链接库文件、动态链接库文件、可执行文件、脚本文件、配置文件等。      我们在编写hellowor

LabVIEW FIFO详解

在LabVIEW的FPGA开发中,FIFO(先入先出队列)是常用的数据传输机制。通过配置FIFO的属性,工程师可以在FPGA和主机之间,或不同FPGA VIs之间进行高效的数据传输。根据具体需求,FIFO有多种类型与实现方式,包括目标范围内FIFO(Target-Scoped)、DMA FIFO以及点对点流(Peer-to-Peer)。 FIFO类型 **目标范围FIFO(Target-Sc

019、JOptionPane类的常用静态方法详解

目录 JOptionPane类的常用静态方法详解 1. showInputDialog()方法 1.1基本用法 1.2带有默认值的输入框 1.3带有选项的输入对话框 1.4自定义图标的输入对话框 2. showConfirmDialog()方法 2.1基本用法 2.2自定义按钮和图标 2.3带有自定义组件的确认对话框 3. showMessageDialog()方法 3.1

函数式编程思想

我们经常会用到各种各样的编程思想,例如面向过程、面向对象。不过笔者在该博客简单介绍一下函数式编程思想. 如果对函数式编程思想进行概括,就是f(x) = na(x) , y=uf(x)…至于其他的编程思想,可能是y=a(x)+b(x)+c(x)…,也有可能是y=f(x)=f(x)/a + f(x)/b+f(x)/c… 面向过程的指令式编程 面向过程,简单理解就是y=a(x)+b(x)+c(x)