pytorch中 nn.utils.rnn.pack_padded_sequence和nn.utils.rnn.pad_packed_sequence

2023-10-17 22:50

本文主要是介绍pytorch中 nn.utils.rnn.pack_padded_sequence和nn.utils.rnn.pad_packed_sequence,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 官方文档:

torch.nn — PyTorch 1.11.0 documentation

 

2. 应用背景:

在使用pytorch处理数据时,一般是采用batch的形式同时处理多个样本序列,而每个batch中的样本序列是不等长的,导致rnn无法处理。所以,通常的做法是先将每个batch按照最长的序列进行padding处理等长的形式。

但padding操作会带来一个问题,那就是对于多数进行padding过的序列,会导致rnn对它的表示多了很多无用的字符,我们希望的是在最后一个有用的字符后就可以输出该序列的向量表示,而不是在很多padding字符后。

这时候,pack操作就派上场了,可以理解成,它是将一个经过padding后的变长序列压紧,压缩后就不含padding的字符0了。具体操作就是:

  • 第一步,padding后的输入序列先经过nn.utils.rnn.pack_padded_sequence,这样会得到一个PackedSequence类型的object,可以直接传给RNN(RNN的源码中的forward函数里上来就是判断输入是否是PackedSequence的实例,进而采取不同的操作,如果是则输出也是该类型。)
  • 第二步,得到的PackedSequence类型的object,正常直接传给RNN,得到的同样是该类型的输出
  • 第三步,再经过nn.utils.rnn.pad_packed_sequence,也就是对经过RNN后的输出重新进行padding操作,得到正常的每个batch等长的序列。

3. 函数详解:

3.1 nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence — PyTorch 1.11.0 documentation

3.2 nn.utils.rnn.pad_packed_sequence

torch.nn.utils.rnn.pad_packed_sequence — PyTorch 1.11.0 documentation

4. 代码实例:

4.1 使用时:

import torch
import torch.nn as nngru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)input = torch.tensor([[1,2,3,4,5],[1,2,3,4,0],[1,2,3,0,0],[1,2,0,0,0]]).unsqueeze(2)
input_lengths = torch.tensor([5,4,3,2])
input = nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first=True, enforce_sorted=False)
print(type(input))
print(input)
output, hidden = gru(input.float())
output, _ = torch.nn.utils.rnn.pad_packed_sequence(sequence=output, batch_first=True)print(output)

 

4.2 不使用时:

import torch
import torch.nn as nngru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)input = torch.tensor([[1,2,3,4,5],[1,2,3,4,0],[1,2,3,0,0],[1,2,0,0,0]]).unsqueeze(2)
input_lengths = torch.tensor([5,4,3,2])
# input = nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first=True, enforce_sorted=False)
print(type(input))
print(input)
output, hidden = gru(input.float())
# output, _ = torch.nn.utils.rnn.pad_packed_sequence(sequence=output, batch_first=True)print(output)

 

5. 注意几个参数:

5.1 batch_first

包括RNN中,参数默认为 False,也就是它鼓励输入的第一维不是batch,这与我们常规输入相悖,我们习惯的输入为(batch_size, seq_len, embedding_dim),所以需要注意,要么该数据输入,要不该参数设置为True。

5.2 enforce_sorted

参数默认为 True,也就是它默认batch中每个序列已经按照长度由大到小进行排列了,所以需要注意,如果没有进行排序,则改为False。

 部分参考:https://www.cnblogs.com/yuqinyuqin/p/14100967.html 

这篇关于pytorch中 nn.utils.rnn.pack_padded_sequence和nn.utils.rnn.pad_packed_sequence的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/228476

相关文章

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑

3、cnn基础 卷积神经网络 输入层 —输入图片矩阵 输入层一般是 RGB 图像或单通道的灰度图像,图片像素值在[0,255],可以用矩阵表示图片 卷积层 —特征提取 人通过特征进行图像识别,根据左图直的笔画判断X,右图曲的笔画判断圆 卷积操作 激活层 —加强特征 池化层 —压缩数据 全连接层 —进行分类 输出层 —输出分类概率 4、基于LeNet

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

torch.nn 与 torch.nn.functional的区别?

区别 PyTorch中torch.nn与torch.nn.functional的区别是:1.继承方式不同;2.可训练参数不同;3.实现方式不同;4.调用方式不同。 1.继承方式不同 torch.nn 中的模块大多数是通过继承torch.nn.Module 类来实现的,这些模块都是Python 类,需要进行实例化才能使用。而torch.nn.functional 中的函数是直接调用的,无需

pytorch计算网络参数量和Flops

from torchsummary import summarysummary(net, input_size=(3, 256, 256), batch_size=-1) 输出的参数是除以一百万(/1000000)M, from fvcore.nn import FlopCountAnalysisinputs = torch.randn(1, 3, 256, 256).cuda()fl

浙大数据结构:02-线性结构4 Pop Sequence

这道题我们采用数组来模拟堆栈和队列。 简单说一下大致思路,我们用栈来存1234.....,队列来存输入的一组数据,栈与队列进行匹配,相同就pop 机翻 1、条件准备 stk是栈,que是队列。 tt指向的是栈中下标,front指向队头,rear指向队尾。 初始化栈顶为0,队头为0,队尾为-1 #include<iostream>using namespace std;#defi

【UVA】1626-Brackets sequence(动态规划)

一道算是比较难理解的动规。 状态转移分2个: (用d[i][j]表示在i~j内最少需要添加几个括号,保持平衡) 1.如果s[i]和s[j]是一对括号,那么d[i][j] = d[i + 1][j - 1] 2.否则的话 d[i][j] = min(d[i][k],[k + 1][j]); 边界是d[i + 1][i] = 0; d[i][i] = 1; 13993644 162

【UVA】10534 - Wavio Sequence(LIS最长上升子序列)

这题一看10000的数据量就知道必须用nlog(n)的时间复杂度。 所以特意去看了最长上升子序列的nlog(n)的算法。 如果有2个位置,该位置上的元素为A[i]和A[j],并且他们满足以下条件: 1.dp[i] = dp[j]    (dp[x]代表以x结尾的最长上升子序列长度) 2.A[i] < A[j] 3.i < j 那么毫无疑问,选择dp[i] 一定优于选择dp[j] 那么

2015年多校联合训练第一场OO’s Sequence(hdu5288)

题意:给定一个长度为n的序列,规定f(l,r)是对于l,r范围内的某个数字a[i],都不能找到一个对应的j使得a[i]%a[j]=0,那么l,r内有多少个i,f(l,r)就是几。问所有f(l,r)的总和是多少。 公式中给出的区间,也就是所有存在的区间。 思路:直接枚举每一个数字,对于这个数字,如果这个数字是合法的i,那么向左能扩展的最大长度是多少,向右能扩展的最大长度是多少,那么i为合法的情况