pytorch花式索引提取topk的张量

2024-02-14 09:52

本文主要是介绍pytorch花式索引提取topk的张量,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • pytorch花式索引提取topk的张量
    • 问题设定
    • 代码实现
      • 索引方法
      • gather方法
      • 验证
    • 补充知识
      • expand方法
      • gather方法
      • randint

pytorch花式索引提取topk的张量

问题设定

在这里插入图片描述
或者说,有一个(bs, dim, L)的大张量,索引的index形状为(bs, X),想得到一个(bs, dim, X)的reduced向量。我们在进行topk操作(以减少计算量)的时候经常碰到这种情况。
给出如下两种实现方法,分别使用花式索引(参考informer的代码)以及pytorch的gather方法

代码实现

索引方法

参考https://blog.csdn.net/qq_36560894/article/details/122005808

feature = torch.rand(2,16,4*4)
indices = torch.randint(0,16, (2, 3))
indices
indices_expand = indices.unsqueeze(1).expand(-1, dim, -1).to(torch.long) # (bs, dim, H*W)
indices_expand.shape
indices_expand[:,1,:] # 结果和indices一致,说明在第二个channel上,每个样本的索引是一样的
bs,dim=feature.shape[:2]
bs,dim 
feature_reduce = feature.view(bs, dim, -1)[torch.arange(bs)[:, None, None], torch.arange(dim)[None,:,None], indices_expand]
feature_reduce.shape

在这里插入图片描述
在这里插入图片描述

gather方法

reduce_feature = torch.gather(feature, 2, indices_expand)

验证

两种方法得到的结果完全相同
在这里插入图片描述

补充知识

expand方法

在 PyTorch 中,expand() 方法用于扩展张量的大小。它会在不实际复制数据的情况下,重复张量的元素以填充新的形状。这个方法可以用于广播操作,以便在执行一些需要相同形状的张量之间的数学运算时,使它们具有相同的形状。

下面是使用 expand() 方法的基本用法:

import torch# 创建一个原始张量
x = torch.tensor([[1, 2, 3],[4, 5, 6]])# 使用 expand 扩展张量的大小
expanded_x = x.expand(2, 3, 4)  # 扩展成维度为(2, 3, 4)的张量print(expanded_x)

在上面的例子中,我们首先创建了一个形状为 (2, 3) 的原始张量 x。然后,我们使用 expand() 方法将其扩展成一个维度为 (2, 3, 4) 的新张量 expanded_x,该张量的形状是在原始张量形状的基础上每个维度都扩展了一倍。

需要注意的是,expand() 方法只能用于增加张量的大小,不能减小。另外,扩展后的张量与原始张量共享底层数据,因此在原始张量上进行的任何修改都会反映在扩展后的张量上,反之亦然。

gather方法

在 PyTorch 中,gather() 方法用于从输入张量中按照指定索引提取元素。这个方法通常用于根据索引收集特定的元素,例如根据类别索引从分类得分张量中获取对应类别的得分。

下面是使用 gather() 方法的基本用法:

import torch# 创建一个输入张量
input_tensor = torch.tensor([[1, 2],[3, 4],[5, 6]])# 创建一个索引张量
indices = torch.tensor([[0, 0],[1, 0]])# 使用 gather 方法根据索引收集元素
output_tensor = torch.gather(input_tensor, dim=1, index=indices)print(output_tensor)

在上面的例子中,我们首先创建了一个形状为 (3, 2) 的输入张量 input_tensor,以及一个形状为 (2, 2) 的索引张量 indices。然后,我们使用 gather() 方法从输入张量 input_tensor 中按照索引张量 indices 收集元素。

gather() 方法中,参数 dim 指定了在哪个维度上进行收集操作,而 index 参数指定了收集元素所使用的索引张量。

需要注意的是,索引张量 indices 的形状必须与输出张量的形状一致,或者是可以广播成与输出张量形状一致的形状。

randint

torch.randint() 是 PyTorch 中用于生成随机整数张量的函数。它可以生成一个张量,其中的元素是在指定范围内随机抽样的整数。

下面是 torch.randint() 的基本用法示例:

import torch# 生成一个形状为 (3, 3) 的随机整数张量,范围是 [0, 10)
random_integers = torch.randint(low=0, high=10, size=(3, 3))print(random_integers)

在上面的示例中,我们使用了 torch.randint() 函数来生成一个形状为 (3, 3) 的随机整数张量,其中的元素取值范围在闭区间 [low, high) 内,即从 0 到 9。

torch.randint() 函数的主要参数包括:

  • low:生成的随机整数的最小值(包含)。
  • high:生成的随机整数的最大值(不包含)。
  • size:生成的张量的形状。

你也可以不指定 low 参数,默认情况下它为 0。此外,还可以使用其他参数来控制生成的随机整数张量的设备类型、数据类型等。

这篇关于pytorch花式索引提取topk的张量的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/708176

相关文章

Python从Word文档中提取图片并生成PPT的操作代码

《Python从Word文档中提取图片并生成PPT的操作代码》在日常办公场景中,我们经常需要从Word文档中提取图片,并将这些图片整理到PowerPoint幻灯片中,手动完成这一任务既耗时又容易出错,... 目录引言背景与需求解决方案概述代码解析代码核心逻辑说明总结引言在日常办公场景中,我们经常需要从 W

Java+AI驱动实现PDF文件数据提取与解析

《Java+AI驱动实现PDF文件数据提取与解析》本文将和大家分享一套基于AI的体检报告智能评估方案,详细介绍从PDF上传、内容提取到AI分析、数据存储的全流程自动化实现方法,感兴趣的可以了解下... 目录一、核心流程:从上传到评估的完整链路二、第一步:解析 PDF,提取体检报告内容1. 引入依赖2. 封装

Java使用正则提取字符串中的内容的详细步骤

《Java使用正则提取字符串中的内容的详细步骤》:本文主要介绍Java中使用正则表达式提取字符串内容的方法,通过Pattern和Matcher类实现,涵盖编译正则、查找匹配、分组捕获、数字与邮箱提... 目录1. 基础流程2. 关键方法说明3. 常见场景示例场景1:提取所有数字场景2:提取邮箱地址4. 高级

Python 字符串裁切与提取全面且实用的解决方案

《Python字符串裁切与提取全面且实用的解决方案》本文梳理了Python字符串处理方法,涵盖基础切片、split/partition分割、正则匹配及结构化数据解析(如BeautifulSoup、j... 目录python 字符串裁切与提取的完整指南 基础切片方法1. 使用切片操作符[start:end]2

使用Python提取PDF大纲(书签)的完整指南

《使用Python提取PDF大纲(书签)的完整指南》PDF大纲(Outline)​​是PDF文档中的导航结构,通常显示在阅读器的侧边栏中,方便用户快速跳转到文档的不同部分,大纲通常以层级结构组织,包含... 目录一、PDF大纲简介二、准备工作所需工具常见安装问题三、代码实现完整代码核心功能解析四、使用效果控

Linux从文件中提取特定内容的实用技巧分享

《Linux从文件中提取特定内容的实用技巧分享》在日常数据处理和配置文件管理中,我们经常需要从大型文件中提取特定内容,本文介绍的提取特定行技术正是这些高级操作的基础,以提取含有1的简单需求为例,我们可... 目录引言1、方法一:使用 grep 命令1.1 grep 命令基础1.2 命令详解1.3 高级用法2

MySQL 索引简介及常见的索引类型有哪些

《MySQL索引简介及常见的索引类型有哪些》MySQL索引是加速数据检索的特殊结构,用于存储列值与位置信息,常见的索引类型包括:主键索引、唯一索引、普通索引、复合索引、全文索引和空间索引等,本文介绍... 目录什么是 mysql 的索引?常见的索引类型有哪些?总结性回答详细解释1. MySQL 索引的概念2

Oracle查询表结构建表语句索引等方式

《Oracle查询表结构建表语句索引等方式》使用USER_TAB_COLUMNS查询表结构可避免系统隐藏字段(如LISTUSER的CLOB与VARCHAR2同名字段),这些字段可能为dbms_lob.... 目录oracle查询表结构建表语句索引1.用“USER_TAB_COLUMNS”查询表结构2.用“a

MySQL 强制使用特定索引的操作

《MySQL强制使用特定索引的操作》MySQL可通过FORCEINDEX、USEINDEX等语法强制查询使用特定索引,但优化器可能不采纳,需结合EXPLAIN分析执行计划,避免性能下降,注意版本差异... 目录1. 使用FORCE INDEX语法2. 使用USE INDEX语法3. 使用IGNORE IND

Python实现批量提取BLF文件时间戳

《Python实现批量提取BLF文件时间戳》BLF(BinaryLoggingFormat)作为Vector公司推出的CAN总线数据记录格式,被广泛用于存储车辆通信数据,本文将使用Python轻松提取... 目录一、为什么需要批量处理 BLF 文件二、核心代码解析:从文件遍历到数据导出1. 环境准备与依赖库