torch.gather——沿特定维度收集数值

2024-01-28 06:30

本文主要是介绍torch.gather——沿特定维度收集数值,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

PyTorch学习笔记:torch.gather——沿特定维度收集数值

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

功能:从输入的数组中,沿指定的dim维度,利用索引变量index,将数据索引出来,并且堆叠成一个数组。直观可能不好理解,具体可以见代码案例。

输入:

input:输入的数组

dim:指定的维度

index:索引变量,数据类型需是长整型(int64)

注意:

  • inputindex具有相同的维数

  • outindex具有相同的形状

  • 除了dim维度,在每个维度上,索引在该维度上的大小要小于等于输入在该维度上的大小,即:
    i n d e x . s i z e ( d ) ≤ i n p u t . s i z e ( d ) , d ! = d i m index.size(d)≤input.size(d),\quad d!=dim index.size(d)input.size(d),d!=dim

代码案例

一般用法,当在一个维度上进行索引时,以第一维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[2,3,1,3]])
b=torch.gather(a,dim=0,index=index)
print(a)
print(b)

输出

在这里插入图片描述

以第二维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[2],[3],[1],[3]])
b=torch.gather(a,dim=1,index=index)
print(a)
print(b)

输出

在这里插入图片描述

当同时在两个维度上进行索引时,以第一维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[1,2,3],[2,3,0],[3,0,1]])
b=torch.gather(a,dim=0,index=index)
print(a)
print(b)

输出

tensor([[ 0,  1,  2,  3,  4],[ 5,  6,  7,  8,  9],[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]])
tensor([[ 5, 11, 17],[10, 16,  2],[15,  1,  7]])

以第二维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[1,2],[2,3],[3,4]])
b=torch.gather(a,dim=1,index=index)
print(a)
print(b)

输出

tensor([[ 0,  1,  2,  3,  4],[ 5,  6,  7,  8,  9],[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]])
tensor([[ 1,  2],[ 7,  8],[13, 14]])

官方文档

torch.gather:https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=torch%20gather#torch.gather

这篇关于torch.gather——沿特定维度收集数值的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/652751

相关文章

使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)

《使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)》在现代软件开发中,处理JSON数据是一项非常常见的任务,无论是从API接口获取数据,还是将数据存储为JSON格式,解析... 目录1. 背景介绍1.1 jsON简介1.2 实际案例2. 准备工作2.1 环境搭建2.1.1 添加

JS常用组件收集

收集了一些平时遇到的前端比较优秀的组件,方便以后开发的时候查找!!! 函数工具: Lodash 页面固定: stickUp、jQuery.Pin 轮播: unslider、swiper 开关: switch 复选框: icheck 气泡: grumble 隐藏元素: Headroom

【编程底层思考】垃圾收集机制,GC算法,垃圾收集器类型概述

Java的垃圾收集(Garbage Collection,GC)机制是Java语言的一大特色,它负责自动管理内存的回收,释放不再使用的对象所占用的内存。以下是对Java垃圾收集机制的详细介绍: 一、垃圾收集机制概述: 对象存活判断:垃圾收集器定期检查堆内存中的对象,判断哪些对象是“垃圾”,即不再被任何引用链直接或间接引用的对象。内存回收:将判断为垃圾的对象占用的内存进行回收,以便重新使用。

理解java虚拟机内存收集

学习《深入理解Java虚拟机》时个人的理解笔记 1、为什么要去了解垃圾收集和内存回收技术? 当需要排查各种内存溢出、内存泄漏问题时,当垃圾收集成为系统达到更高并发量的瓶颈时,我们就必须对这些“自动化”的技术实施必要的监控和调节。 2、“哲学三问”内存收集 what?when?how? 那些内存需要回收?什么时候回收?如何回收? 这是一个整体的问题,确定了什么状态的内存可以

Android中如何实现adb向应用发送特定指令并接收返回

1 ADB发送命令给应用 1.1 发送自定义广播给系统或应用 adb shell am broadcast 是 Android Debug Bridge (ADB) 中用于向 Android 系统发送广播的命令。通过这个命令,开发者可以发送自定义广播给系统或应用,触发应用中的广播接收器(BroadcastReceiver)。广播机制是 Android 的一种组件通信方式,应用可以监听广播来执行

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

torch.nn 与 torch.nn.functional的区别?

区别 PyTorch中torch.nn与torch.nn.functional的区别是:1.继承方式不同;2.可训练参数不同;3.实现方式不同;4.调用方式不同。 1.继承方式不同 torch.nn 中的模块大多数是通过继承torch.nn.Module 类来实现的,这些模块都是Python 类,需要进行实例化才能使用。而torch.nn.functional 中的函数是直接调用的,无需

Linux使用收集--持续更新

linux查看目录文件数》》》 查看当前目录大小: [root@xker.com]# du -sh 查看指定目录大小: [root@xker.com]# du -sh /www/xker.com 查看当前目录文件总数: [root@xker.com]# find . -type f |wc -l 查看指定目录文件总数: [root@xker.com]# fi

交换两个变量数值的3种方法

前言:交换两个数值可不是"a = b,b = a"。这样做的话,a先等于了b的值;当“b = a”后,因为此时a已经等于b的值了,这个语句就相当于执行了b = b。最终的数值关系就成了a == b,b == b。 下面教给大家3种交换变量数值的方法: 目录 1. 中介法 2. 消和法 3. 异或法 4. 总结 1. 中介法 中介法(又称 临时变量法 或 酱油法),其中心

后台开发 知识点收集

原知识点总结连接,由于有些问题比较熟悉,所以就没有在自己文章中再列出来了 计算机网络 tcp/udp区别http状态码http协议报头字段osi模型、tcp/ip模型以及各层对应的协议session机制、cookie机制tcp三次握手,四次挥手打开网页到页面显示之间的过程https和http的区别post和get的区别ip子网划分两个网络MTU不同时如何通信 数据库 常见问题mysql的两