pytorch笔记:PackedSequence对象送入RNN

2023-11-02 05:15

本文主要是介绍pytorch笔记:PackedSequence对象送入RNN,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch 笔记:PAD_PACKED_SEQUENCE 和PACK_PADDED_SEQUENCE-CSDN博客 

  • 当使用pack_padded_sequence得到一个PackedSequence对象并将其送入RNN(如LSTM或GRU)时,RNN内部会进行特定的操作来处理这种特殊的输入形式。
  • 使用PackedSequence的主要好处是提高效率和计算速度。因为通过跳过填充部分,RNN不需要在这些部分进行无用的计算。这特别对于处理长度差异很大的批量序列时很有帮助。

1 PackedSequence对象

  • PackedSequence是一个命名元组,其中主要的两个属性是databatch_sizes
    • data是一个1D张量,包含所有非零长度序列的元素,按照其在批次中的顺序排列。
    • batch_sizes是一个1D张量,表示每个时间步的批次大小
  • PackedSequence(data=tensor([6, 5, 1, 8, 7, 9]),batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)

 2 处理PackedSequence

  • 当RNN遇到PackedSequence作为输入时,它会按照batch_sizes中指定的方式对data进行迭代
  • 举例来说,上面例子中batch_sizes[3,2,1],那么RNN首先处理前3个元素,然后是接下来的2个元素,最后是最后一个元素。
  • 这允许RNN仅处理有效的序列部分,而跳过填充

3 输出

  • 当RNN完成对PackedSequence的处理后,它的输出同样是一个PackedSequence对象
  • 可以使用pad_packed_sequence将其转换回常规的填充张量格式,以进行后续操作或损失计算
  • 隐藏状态和单元状态(对于LSTM)也会被返回,这些状态与未打包的序列的处理方式相同

4  举例

  • 假设我们有以下3个句子,我们想要用RNN进行处理:
I love AI
Hello
PyTorch is great
  • 为了送入RNN,我们首先需要将这些句子转换为整数形式,并进行填充以保证它们在同一个批次中有相同的长度。
{'PAD': 0,'I': 1,'love': 2,'AI': 3,'Hello': 4,'PyTorch': 5,'is': 6,'great': 7
}
  • 句子转换为整数后(id):
  1. I love AI -> [1, 2, 3]
  2. Hello -> [4]
  3. PyTorch is great -> [5, 6, 7]
  • 为了将它们放入同一个批次,我们进行填充:
[1, 2, 3]
[4, 0, 0]
[5, 6, 7]
  • 假设每个单词的id 对应的embedding就是自己:
[[1], [2], [3]]
[[4], [0], [0]]
[[5], [6], [7]]
  • 使用pack_padded_sequence进行处理
import torch
from torch.nn.utils.rnn import pack_padded_sequence# 输入序列
input_seq = torch.tensor([[1,2,3], [4, 0, 0], [5,6,7]])
input_seq=input_seq.reshape(data.shape[0],input_seq.shape[1],1)
#每个单词id的embedding就是他自己
input_seq=input_seq.float()
#变成float是为了喂入RNN所需# 序列的实际长度
lengths = [3, 1, 3]# 使用pack_padded_sequence
packed = pack_padded_sequence(input_seq, lengths, batch_first=True,enforce_sorted=False)packed
'''
PackedSequence(data=tensor([[1.],[5.],[4.],[2.],[6.],[3.],[7.]]), batch_sizes=tensor([3, 2, 2]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))
'''
  • 现在,当我们将此PackedSequence送入RNN时,RNN首先处理前3个元素,因为batch_sizes的第一个元素是3。然后,它处理接下来的2个元素,最后处理剩下的2个元素。
    • 具体来说,RNN会如下处理:

      • 时间步1:根据batch_sizes[0] = 3,RNN同时处理三个句子的第一个元素。具体地说,它处理句子1的"I",句子2的"PyTorch",和句子3的"Hello"。
      • 时间步2:根据batch_sizes[1] = 2,RNN处理接下来两个句子的第二个元素,即句子1的"love"和句子2的"is"。
      • 时间步3:根据batch_sizes[2] = 2,RNN处理接下来两个句子的第三个元素,即句子1的"AI"和句子2的"great"。
  • 喂入RNN
import torch.nn as nnclass SimpleRNN(nn.Module):def __init__(self,input_size,hidden_size,num_layer=1):super(SimpleRNN,self).__init__()self.rnn=nn.RNN(input_size,hidden_size,num_layer,batch_first=True)def forward(self,x,hidden=None):packed_output,h_n=self.rnn(x,hidden)return packed_output,h_n
#单层的RNNSrnn=SimpleRNN(1,3)
Srnn(packed_data) 
'''
(PackedSequence(data=tensor([[-0.1207, -0.0247,  0.4188],[-0.3173, -0.0499,  0.6838],[-0.4900, -0.0751,  0.8415],[-0.7051, -0.1611,  0.9610],[-0.7497, -0.2117,  0.9829],[-0.3361, -0.1660,  0.9329],[ 0.4608, -0.0492,  0.1138]], grad_fn=<CatBackward0>), batch_sizes=tensor([3, 2, 2]), sorted_indices=None, unsorted_indices=None),tensor([[[-0.3361, -0.1660,  0.9329],[ 0.4608, -0.0492,  0.1138],[-0.4900, -0.0751,  0.8415]]], grad_fn=<StackBackward0>))
'''
  • 得到的RNN输出是pack的,hidden state没有变化
    • Srnn=SimpleRNN(1,3)
      Srnn(packed_data) 
      '''
      (PackedSequence(data=tensor([[-0.1207, -0.0247,  0.4188],[-0.3173, -0.0499,  0.6838],[-0.4900, -0.0751,  0.8415],[-0.7051, -0.1611,  0.9610],[-0.7497, -0.2117,  0.9829],[-0.3361, -0.1660,  0.9329],[ 0.4608, -0.0492,  0.1138]], grad_fn=<CatBackward0>), batch_sizes=tensor([3, 2, 2]), sorted_indices=None, unsorted_indices=None),tensor([[[-0.3361, -0.1660,  0.9329],[ 0.4608, -0.0492,  0.1138],[-0.4900, -0.0751,  0.8415]]], grad_fn=<StackBackward0>))
      '''pad_packed_sequence(Srnn(packed_data)[0],batch_first=True)
      '''
      (tensor([[[-0.1207, -0.0247,  0.4188],[-0.7051, -0.1611,  0.9610],[-0.3361, -0.1660,  0.9329]],[[-0.3173, -0.0499,  0.6838],[-0.7497, -0.2117,  0.9829],[ 0.4608, -0.0492,  0.1138]],[[-0.4900, -0.0751,  0.8415],[ 0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000]]], grad_fn=<TransposeBackward0>),tensor([3, 3, 1]))
      '''

这篇关于pytorch笔记:PackedSequence对象送入RNN的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/328371

相关文章

JSON字符串转成java的Map对象详细步骤

《JSON字符串转成java的Map对象详细步骤》:本文主要介绍如何将JSON字符串转换为Java对象的步骤,包括定义Element类、使用Jackson库解析JSON和添加依赖,文中通过代码介绍... 目录步骤 1: 定义 Element 类步骤 2: 使用 Jackson 库解析 jsON步骤 3: 添

Spring常见错误之Web嵌套对象校验失效解决办法

《Spring常见错误之Web嵌套对象校验失效解决办法》:本文主要介绍Spring常见错误之Web嵌套对象校验失效解决的相关资料,通过在Phone对象上添加@Valid注解,问题得以解决,需要的朋... 目录问题复现案例解析问题修正总结  问题复现当开发一个学籍管理系统时,我们会提供了一个 API 接口去

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

Java如何通过反射机制获取数据类对象的属性及方法

《Java如何通过反射机制获取数据类对象的属性及方法》文章介绍了如何使用Java反射机制获取类对象的所有属性及其对应的get、set方法,以及如何通过反射机制实现类对象的实例化,感兴趣的朋友跟随小编一... 目录一、通过反射机制获取类对象的所有属性以及相应的get、set方法1.遍历类对象的所有属性2.获取

java中VO PO DTO POJO BO DO对象的应用场景及使用方式

《java中VOPODTOPOJOBODO对象的应用场景及使用方式》文章介绍了Java开发中常用的几种对象类型及其应用场景,包括VO、PO、DTO、POJO、BO和DO等,并通过示例说明了它... 目录Java中VO PO DTO POJO BO DO对象的应用VO (View Object) - 视图对象

vue如何监听对象或者数组某个属性的变化详解

《vue如何监听对象或者数组某个属性的变化详解》这篇文章主要给大家介绍了关于vue如何监听对象或者数组某个属性的变化,在Vue.js中可以通过watch监听属性变化并动态修改其他属性的值,watch通... 目录前言用watch监听深度监听使用计算属性watch和计算属性的区别在vue 3中使用watchE

Java将时间戳转换为Date对象的方法小结

《Java将时间戳转换为Date对象的方法小结》在Java编程中,处理日期和时间是一个常见需求,特别是在处理网络通信或者数据库操作时,本文主要为大家整理了Java中将时间戳转换为Date对象的方法... 目录1. 理解时间戳2. Date 类的构造函数3. 转换示例4. 处理可能的异常5. 考虑时区问题6.

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear