pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法

本文主要是介绍pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

数组排序并返回前N值

对数组的第n个维度进项排序,并返回排序的前k个元素的values, indices

torch.topk(input, k, dim=n, largest=True, sorted=True, out=None) 
-> (Tensor, LongTensor)

例:取input的第1维

values, indices = torch.topk(input , 1, dim=1)

l a r g e s t { T r u e ,按照大到小排序 F a l s e ,按照小到大排序 largest\left\{\begin{array}{l}True,\mathrm{按照大到小排序}\\False,\mathrm{按照小到大排序}\end{array}\right. largest{True按照大到小排序False按照小到大排序

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

sorted:返回的结果按照顺序返回

out:可缺省,不要

按照索引取值:torch.gather(input,dim,index),或indicat_select

import torch
input = [[2, 3, 4, 5, 0, 0],[1, 4, 3, 0, 0, 0],[4, 2, 2, 5, 7, 0],[1, 0, 0, 0, 0, 0]
]
input = torch.tensor(input)
#注意index的类型
index = torch.LongTensor([[3],[2],[4],[0]])
#index之所以减1,是因为序列维度是从0开始计算的
out = torch.gather(input, 1, index)
————————————————
版权声明:https://blog.csdn.net/cpluss/article/details/90260550 https://www.zhihu.com/question/374472015

截取Tensor

初始方法

```c
res1 = []
for i in range(10):res1.append(i*3)
res = out[:, res1]
## [narrow](https://pytorch.org/docs/stable/generated/torch.narrow.html?highlight=narrow#torch.narrow)维度范围返回 。返回的张量和张量共享相同的底层存储。
Narrow()的工作原理类似于高级索引。例如,在一个2D张量中,使用[:,0:5]选择列0到5中的所有行。同样的,可以使用torch.narrow(1,0,5)。然而,在高维张量中,对于每个维度都使用range操作是很麻烦的。使用narrow()可以更快更方便地实现这一点。```c
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> torch.narrow(x, 0, 0, 2)# 沿着x的第0维度,的第0位置开始,向下选取2个距离
tensor([[ 1,  2,  3],[ 4,  5,  6]])
>>> torch.narrow(x, 1, 1, 2)# 沿着x的第1维度,的第1位置开始,向下选取2个距离
tensor([[ 2,  3],[ 5,  6],[ 8,  9]])

mask方式选取torch.masked_select(input,mask)

>>> import torch
>>> x = torch.randn([3, 4])
>>> print(x)tensor([[ 1.2001,  1.2968, -0.6657, -0.6907],[-2.0099,  0.6249, -0.5382,  1.4458],[ 0.0684,  0.4118,  0.1011, -0.5684]])>>> # 将x中的每一个元素与0.5进行比较
>>> # 当元素大于等于0.5返回True,否则返回False
>>> mask = x.ge(0.5)
>>> print(mask)tensor([[ True,  True, False, False],[False,  True, False,  True],[False, False, False, False]])>>> print(torch.masked_select(x, mask))tensor([1.2001, 1.2968, 0.6249, 1.4458])————————————————
版权声明  https://pytorch.org/docs/stable/generated/torch.masked_select.html#torch.masked_select   https://cloud.tencent.com/developer/article/1755706

permute置换操作+res = input[:,0:N]

where()

>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],[-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],[0.0000, 0.0000]], dtype=torch.float64)



这篇关于pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1894

相关文章

Spring Boot @RestControllerAdvice全局异常处理最佳实践

《SpringBoot@RestControllerAdvice全局异常处理最佳实践》本文详解SpringBoot中通过@RestControllerAdvice实现全局异常处理,强调代码复用、统... 目录前言一、为什么要使用全局异常处理?二、核心注解解析1. @RestControllerAdvice2

golang中reflect包的常用方法

《golang中reflect包的常用方法》Go反射reflect包提供类型和值方法,用于获取类型信息、访问字段、调用方法等,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值... 目录reflect包方法总结类型 (Type) 方法值 (Value) 方法reflect包方法总结

C# 比较两个list 之间元素差异的常用方法

《C#比较两个list之间元素差异的常用方法》:本文主要介绍C#比较两个list之间元素差异,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. 使用Except方法2. 使用Except的逆操作3. 使用LINQ的Join,GroupJoin

MySQL查询JSON数组字段包含特定字符串的方法

《MySQL查询JSON数组字段包含特定字符串的方法》在MySQL数据库中,当某个字段存储的是JSON数组,需要查询数组中包含特定字符串的记录时传统的LIKE语句无法直接使用,下面小编就为大家介绍两种... 目录问题背景解决方案对比1. 精确匹配方案(推荐)2. 模糊匹配方案参数化查询示例使用场景建议性能优

关于集合与数组转换实现方法

《关于集合与数组转换实现方法》:本文主要介绍关于集合与数组转换实现方法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、Arrays.asList()1.1、方法作用1.2、内部实现1.3、修改元素的影响1.4、注意事项2、list.toArray()2.1、方

Python中注释使用方法举例详解

《Python中注释使用方法举例详解》在Python编程语言中注释是必不可少的一部分,它有助于提高代码的可读性和维护性,:本文主要介绍Python中注释使用方法的相关资料,需要的朋友可以参考下... 目录一、前言二、什么是注释?示例:三、单行注释语法:以 China编程# 开头,后面的内容为注释内容示例:示例:四

一文详解Git中分支本地和远程删除的方法

《一文详解Git中分支本地和远程删除的方法》在使用Git进行版本控制的过程中,我们会创建多个分支来进行不同功能的开发,这就容易涉及到如何正确地删除本地分支和远程分支,下面我们就来看看相关的实现方法吧... 目录技术背景实现步骤删除本地分支删除远程www.chinasem.cn分支同步删除信息到其他机器示例步骤

MySQL 获取字符串长度及注意事项

《MySQL获取字符串长度及注意事项》本文通过实例代码给大家介绍MySQL获取字符串长度及注意事项,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录mysql 获取字符串长度详解 核心长度函数对比⚠️ 六大关键注意事项1. 字符编码决定字节长度2

在Golang中实现定时任务的几种高效方法

《在Golang中实现定时任务的几种高效方法》本文将详细介绍在Golang中实现定时任务的几种高效方法,包括time包中的Ticker和Timer、第三方库cron的使用,以及基于channel和go... 目录背景介绍目的和范围预期读者文档结构概述术语表核心概念与联系故事引入核心概念解释核心概念之间的关系

在Linux终端中统计非二进制文件行数的实现方法

《在Linux终端中统计非二进制文件行数的实现方法》在Linux系统中,有时需要统计非二进制文件(如CSV、TXT文件)的行数,而不希望手动打开文件进行查看,例如,在处理大型日志文件、数据文件时,了解... 目录在linux终端中统计非二进制文件的行数技术背景实现步骤1. 使用wc命令2. 使用grep命令