pytorch nn.utils.rnn.pack_padded_sequence 分析

2023-10-17 22:50

本文主要是介绍pytorch nn.utils.rnn.pack_padded_sequence 分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pack_padded_sequence

在nlp模型的forward方法中,可能有以下调用令读者疑惑

packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths, batch_first=True, enforce_sorted=False)

为什么要使用pack_padded_sequence?

参考

  • Pytorch中的RNN之pack_padded_sequence()和pad_packed_sequence()
  • Pytorch中pack_padded_sequence和pad_packed_sequence的理解

当我们训练RNN时,如果想要进行批次化训练,由于句子的长短不一,所以需要截断和填充。

  • 为什么要截断?对于那些太长的句子,一般选择一个合适的长度来进行截断。
  • 为什么要填充?对于那些太短的句子,需要以 填充字符(比如<pad>)填充,使得该批次内所有的句子长度相同。

但是,填充会带来其它问题:

  • 增加了计算复杂度。假设一个批次内有2个句子,长度分别为5和2。我们要保证批次内所有的句子长度相同,就需要把长度为2的句子填充为5。这样喂给RNN时,需要计算 2 × 5 = 10 2 \times 5 =10 2×5=10次,而实际真正需要的是 5 + 2 = 7 5+2=7 5+2=7次。
  • 得到的结果可能不准确。我们知道RNN取的是最后一个时间步的隐藏状态做为输出,虽然在填充时,一般是以全0的词向量填充,RNN神经元的权重乘以零不会影响最终的输出,但还有偏差 b b b,如果 b ≠ 0 b \neq 0 b=0,还是会影响到最后的输出。

    当然这个问题不大,主要是第1个问题,毕竟批次大小很大的时候影响还是不小的。

我们用图解进一步说明这个问题。假设某句子“Yes”只有一个单词,但是填充了多余的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差

那么我们正确的做法应该是怎么样呢?在上面这个例子,我们想要得到的仅仅是LSTM过完单词"Yes"之后的表示,而不是通过了多个无用的“Pad”得到的表示,如下图:

所以,Pytorch提供了pack_padded_sequence方法来压缩填充字符,加快RNN的计算效率。

pack_padded_sequence是如何压缩的?

那么它是如何做压缩的呢?举个例子,假如一个batch里有5个句子,长度分别是5、4、3、3、2、1。将它们按列压缩,在这个过程中删除了pad字符。所以你可以想象这样的训练过程:

  1. 第一个batch有5个单词,[I, I, This, No, Yes],它们被送入LSTM。
  2. 第二个batch有4个单词被送入LSTM。
  3. 以此类推,之后的batch长度逐渐减小,分别是3、3、2、1
    在这个过程中,pad字符被自然地忽略掉了。

pack_padded_sequence的参数含义

必备参数是句子向量embedded,以及每个句子长度的变量text_lengths。前者通常包含3个维度,即[批次大小、句子最大长度、单词向量长度](前两者顺序可换);后者通常是list类型,或者一维Tensor类型,包含了每个句子的长度。

  • batch_first表示输入的向量是batch维度优先的。
  • enforce_sorted代表输入的句子是否已经按照长度顺序排好,如果为False,那么函数估计会先按照长度排好,进行计算,再还原回原来的顺序。

这篇关于pytorch nn.utils.rnn.pack_padded_sequence 分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/228478

相关文章

kotlin中const 和val的区别及使用场景分析

《kotlin中const和val的区别及使用场景分析》在Kotlin中,const和val都是用来声明常量的,但它们的使用场景和功能有所不同,下面给大家介绍kotlin中const和val的区别,... 目录kotlin中const 和val的区别1. val:2. const:二 代码示例1 Java

Go标准库常见错误分析和解决办法

《Go标准库常见错误分析和解决办法》Go语言的标准库为开发者提供了丰富且高效的工具,涵盖了从网络编程到文件操作等各个方面,然而,标准库虽好,使用不当却可能适得其反,正所谓工欲善其事,必先利其器,本文将... 目录1. 使用了错误的time.Duration2. time.After导致的内存泄漏3. jsO

Spring事务中@Transactional注解不生效的原因分析与解决

《Spring事务中@Transactional注解不生效的原因分析与解决》在Spring框架中,@Transactional注解是管理数据库事务的核心方式,本文将深入分析事务自调用的底层原理,解释为... 目录1. 引言2. 事务自调用问题重现2.1 示例代码2.2 问题现象3. 为什么事务自调用会失效3

找不到Anaconda prompt终端的原因分析及解决方案

《找不到Anacondaprompt终端的原因分析及解决方案》因为anaconda还没有初始化,在安装anaconda的过程中,有一行是否要添加anaconda到菜单目录中,由于没有勾选,导致没有菜... 目录问题原因问http://www.chinasem.cn题解决安装了 Anaconda 却找不到 An

Spring定时任务只执行一次的原因分析与解决方案

《Spring定时任务只执行一次的原因分析与解决方案》在使用Spring的@Scheduled定时任务时,你是否遇到过任务只执行一次,后续不再触发的情况?这种情况可能由多种原因导致,如未启用调度、线程... 目录1. 问题背景2. Spring定时任务的基本用法3. 为什么定时任务只执行一次?3.1 未启用

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

C++ 各种map特点对比分析

《C++各种map特点对比分析》文章比较了C++中不同类型的map(如std::map,std::unordered_map,std::multimap,std::unordered_multima... 目录特点比较C++ 示例代码 ​​​​​​代码解释特点比较1. std::map底层实现:基于红黑

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

Spring、Spring Boot、Spring Cloud 的区别与联系分析

《Spring、SpringBoot、SpringCloud的区别与联系分析》Spring、SpringBoot和SpringCloud是Java开发中常用的框架,分别针对企业级应用开发、快速开... 目录1. Spring 框架2. Spring Boot3. Spring Cloud总结1. Sprin

Spring 中 BeanFactoryPostProcessor 的作用和示例源码分析

《Spring中BeanFactoryPostProcessor的作用和示例源码分析》Spring的BeanFactoryPostProcessor是容器初始化的扩展接口,允许在Bean实例化前... 目录一、概览1. 核心定位2. 核心功能详解3. 关键特性二、Spring 内置的 BeanFactory