pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法

本文主要是介绍pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

数组排序并返回前N值

对数组的第n个维度进项排序,并返回排序的前k个元素的values, indices

torch.topk(input, k, dim=n, largest=True, sorted=True, out=None) 
-> (Tensor, LongTensor)

例:取input的第1维

values, indices = torch.topk(input , 1, dim=1)

l a r g e s t { T r u e ,按照大到小排序 F a l s e ,按照小到大排序 largest\left\{\begin{array}{l}True,\mathrm{按照大到小排序}\\False,\mathrm{按照小到大排序}\end{array}\right. largest{True按照大到小排序False按照小到大排序

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

sorted:返回的结果按照顺序返回

out:可缺省,不要

按照索引取值:torch.gather(input,dim,index),或indicat_select

import torch
input = [[2, 3, 4, 5, 0, 0],[1, 4, 3, 0, 0, 0],[4, 2, 2, 5, 7, 0],[1, 0, 0, 0, 0, 0]
]
input = torch.tensor(input)
#注意index的类型
index = torch.LongTensor([[3],[2],[4],[0]])
#index之所以减1,是因为序列维度是从0开始计算的
out = torch.gather(input, 1, index)
————————————————
版权声明:https://blog.csdn.net/cpluss/article/details/90260550 https://www.zhihu.com/question/374472015

截取Tensor

初始方法

```c
res1 = []
for i in range(10):res1.append(i*3)
res = out[:, res1]
## [narrow](https://pytorch.org/docs/stable/generated/torch.narrow.html?highlight=narrow#torch.narrow)维度范围返回 。返回的张量和张量共享相同的底层存储。
Narrow()的工作原理类似于高级索引。例如,在一个2D张量中,使用[:,0:5]选择列0到5中的所有行。同样的,可以使用torch.narrow(1,0,5)。然而,在高维张量中,对于每个维度都使用range操作是很麻烦的。使用narrow()可以更快更方便地实现这一点。```c
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> torch.narrow(x, 0, 0, 2)# 沿着x的第0维度,的第0位置开始,向下选取2个距离
tensor([[ 1,  2,  3],[ 4,  5,  6]])
>>> torch.narrow(x, 1, 1, 2)# 沿着x的第1维度,的第1位置开始,向下选取2个距离
tensor([[ 2,  3],[ 5,  6],[ 8,  9]])

mask方式选取torch.masked_select(input,mask)

>>> import torch
>>> x = torch.randn([3, 4])
>>> print(x)tensor([[ 1.2001,  1.2968, -0.6657, -0.6907],[-2.0099,  0.6249, -0.5382,  1.4458],[ 0.0684,  0.4118,  0.1011, -0.5684]])>>> # 将x中的每一个元素与0.5进行比较
>>> # 当元素大于等于0.5返回True,否则返回False
>>> mask = x.ge(0.5)
>>> print(mask)tensor([[ True,  True, False, False],[False,  True, False,  True],[False, False, False, False]])>>> print(torch.masked_select(x, mask))tensor([1.2001, 1.2968, 0.6249, 1.4458])————————————————
版权声明  https://pytorch.org/docs/stable/generated/torch.masked_select.html#torch.masked_select   https://cloud.tencent.com/developer/article/1755706

permute置换操作+res = input[:,0:N]

where()

>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],[-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],[0.0000, 0.0000]], dtype=torch.float64)



这篇关于pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1894

相关文章

Java function函数式接口的使用方法与实例

《Javafunction函数式接口的使用方法与实例》:本文主要介绍Javafunction函数式接口的使用方法与实例,函数式接口如一支未完成的诗篇,用Lambda表达式作韵脚,将代码的机械美感... 目录引言-当代码遇见诗性一、函数式接口的生物学解构1.1 函数式接口的基因密码1.2 六大核心接口的形态学

Python实现文件下载、Cookie以及重定向的方法代码

《Python实现文件下载、Cookie以及重定向的方法代码》本文主要介绍了如何使用Python的requests模块进行网络请求操作,涵盖了从文件下载、Cookie处理到重定向与历史请求等多个方面,... 目录前言一、下载网络文件(一)基本步骤(二)分段下载大文件(三)常见问题二、requests模块处理

Linux内存泄露的原因排查和解决方案(内存管理方法)

《Linux内存泄露的原因排查和解决方案(内存管理方法)》文章主要介绍了运维团队在Linux处理LB服务内存暴涨、内存报警问题的过程,从发现问题、排查原因到制定解决方案,并从中学习了Linux内存管理... 目录一、问题二、排查过程三、解决方案四、内存管理方法1)linux内存寻址2)Linux分页机制3)

vue基于ElementUI动态设置表格高度的3种方法

《vue基于ElementUI动态设置表格高度的3种方法》ElementUI+vue动态设置表格高度的几种方法,抛砖引玉,还有其它方法动态设置表格高度,大家可以开动脑筋... 方法一、css + js的形式这个方法需要在表格外层设置一个div,原理是将表格的高度设置成外层div的高度,所以外层的div需要

5分钟获取deepseek api并搭建简易问答应用

《5分钟获取deepseekapi并搭建简易问答应用》本文主要介绍了5分钟获取deepseekapi并搭建简易问答应用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需... 目录1、获取api2、获取base_url和chat_model3、配置模型参数方法一:终端中临时将加

Java中List转Map的几种具体实现方式和特点

《Java中List转Map的几种具体实现方式和特点》:本文主要介绍几种常用的List转Map的方式,包括使用for循环遍历、Java8StreamAPI、ApacheCommonsCollect... 目录前言1、使用for循环遍历:2、Java8 Stream API:3、Apache Commons

Python判断for循环最后一次的6种方法

《Python判断for循环最后一次的6种方法》在Python中,通常我们不会直接判断for循环是否正在执行最后一次迭代,因为Python的for循环是基于可迭代对象的,它不知道也不关心迭代的内部状态... 目录1.使用enuhttp://www.chinasem.cnmerate()和len()来判断for

Java循环创建对象内存溢出的解决方法

《Java循环创建对象内存溢出的解决方法》在Java中,如果在循环中不当地创建大量对象而不及时释放内存,很容易导致内存溢出(OutOfMemoryError),所以本文给大家介绍了Java循环创建对象... 目录问题1. 解决方案2. 示例代码2.1 原始版本(可能导致内存溢出)2.2 修改后的版本问题在

四种Flutter子页面向父组件传递数据的方法介绍

《四种Flutter子页面向父组件传递数据的方法介绍》在Flutter中,如果父组件需要调用子组件的方法,可以通过常用的四种方式实现,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录方法 1:使用 GlobalKey 和 State 调用子组件方法方法 2:通过回调函数(Callb

一文详解Python中数据清洗与处理的常用方法

《一文详解Python中数据清洗与处理的常用方法》在数据处理与分析过程中,缺失值、重复值、异常值等问题是常见的挑战,本文总结了多种数据清洗与处理方法,文中的示例代码简洁易懂,有需要的小伙伴可以参考下... 目录缺失值处理重复值处理异常值处理数据类型转换文本清洗数据分组统计数据分箱数据标准化在数据处理与分析过