tf.nn.dynamic_rnn的输出outputs和state含义

2024-02-07 14:48

本文主要是介绍tf.nn.dynamic_rnn的输出outputs和state含义,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

写在最前面:个人总结:

tf.nn.dynamic_rnn的返回值有两个:outputs和state

outputs:RNN/LSTM/GRU 的每个time_step都有一个输出,outputs把每个timestep的输出增加一个维度,并沿时间顺序在该维度串联。outputs.shape=[batch_size, max_time, hidden_size],要想取某个time_step的输出,只需要用对应的索引即可:output_timestep_k = outputs[:,k,:]

state:记录RNN/LSTM/GRU的最后一个time_step的cell状态,以LSTM为例子,state=(c,h),c代表最后一个step的Ct,h代表最后一个step的ht. 其中ht与outputs[:,-1,:]相等,是同一个东西。

---------------------------------------------------------------------------------------------------------------------------------------------------

以下是转载的具体内容,原文链接:https://blog.csdn.net/u010960155/article/details/81707498

-----------------------------------------------------------------------------------------------------------------------------------------------------

一、 tf.nn.dynamic_rnn的输出

tf.nn.dynamic_rnn的输入参数如下

 tf.nn.dynamic_rnn(cell,inputs,sequence_length=None,initial_state=None,dtype=None,parallel_iterations=None,swap_memory=False,time_major=False,scope=None)

 tf.nn.dynamic_rnn的返回值有两个:outputs和state

为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。

    outputs. outputs是一个tensor
        如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)
        如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]
    state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state

那为什么state输出形状会有变化呢?state和output又有什么关系呢?
二、state含义

对于第一问题“state”形状为什么会发生变化呢?

我们以LSTM和GRU分别为tf.nn.dynamic_rnn的输入cell类型为例,当cell为LSTM,state形状为[2,batch_size, cell.output_size ];当cell为GRU时,state形状为[batch_size, cell.output_size ]。其原因是因为LSTM和GRU的结构本身不同,如下面两个图所示,这是LSTM的cell结构,每个cell会有两个输出:Ct 和 ht,上面这个图是输出Ct,代表哪些信息应该被记住哪些应该被遗忘; 下面这个图是输出ht,代表这个cell的最终输出,LSTM的state是由Ct 和 ht组成的。

当cell为GRU时,state就只有一个了,原因是GRU将Ct 和 ht进行了简化,将其合并成了ht,如下图所示,GRU将遗忘门和输入门合并成了更新门,另外cell不在有细胞状态cell state,只有hidden state。

对于第二个问题outputs和state有什么关系?

结论上来说,如果cell为LSTM,那 state是个tuple,分别代表Ct 和 ht,其中 ht与outputs中的对应的最后一个时刻的输出相等,假设state形状为[ 2,batch_size, cell.output_size ],outputs形状为 [ batch_size, max_time, cell.output_size ],那么state[ 1, batch_size, : ] == outputs[ batch_size, -1, : ];如果cell为GRU,那么同理,state其实就是 ht,state ==outputs[ -1 ]

 
三、实验

我们写点代码来具体感觉下outputs和state是什么,代码如下  

 import tensorflow as tfimport numpy as npdef dynamic_rnn(rnn_type='lstm'):# 创建输入数据,3代表batch size,6代表输入序列的最大步长(max time),8代表每个序列的维度X = np.random.randn(3, 6, 4)# 第二个输入的实际长度为4X[1, 4:] = 0#记录三个输入的实际步长X_lengths = [6, 4, 6]rnn_hidden_size = 5if rnn_type == 'lstm':cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)else:cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)outputs, last_states = tf.nn.dynamic_rnn(cell=cell,dtype=tf.float64,sequence_length=X_lengths,inputs=X)with tf.Session() as session:session.run(tf.global_variables_initializer())o1, s1 = session.run([outputs, last_states])print(np.shape(o1))print(o1)print(np.shape(s1))print(s1)if __name__ == '__main__':dynamic_rnn(rnn_type='lstm')

实验一:cell类型为LSTM,我们看看输出是什么样子,如下图所示,输入的形状为 [ 3, 6, 4 ],经过tf.nn.dynamic_rnn后outputs的形状为 [ 3, 6, 5 ],state形状为 [ 2, 3, 5 ],其中state第一部分为c,代表cell state;第二部分为h,代表hidden state。可以看到hidden state 与 对应的outputs的最后一行是相等的。另外需要注意的是输入一共有三个序列,但第二个序列的长度只有4,可以看到outputs中对应的两行值都为0,所以hidden state对应的是最后一个不为0的部分。tf.nn.dynamic_rnn通过设置sequence_length来实现这一逻辑。  

 (3, 6, 5)[[[ 0.0146346  -0.04717453 -0.06930042 -0.06065602  0.02456717][-0.05580321  0.08770171 -0.04574306 -0.01652854 -0.04319528][ 0.09087799  0.03535907 -0.06974291 -0.03757408 -0.15553619][ 0.10003044  0.10654698  0.21004055  0.13792148 -0.05587583][ 0.13547596 -0.014292   -0.0211154  -0.10857875  0.04461256][ 0.00417564 -0.01985144  0.00050634 -0.13238986  0.14323784]][[ 0.04893576  0.14289175  0.17957205  0.09093887 -0.0507192 ][ 0.17696126  0.09929577  0.21185635  0.20386451  0.11664373][ 0.15658667  0.03952745 -0.03425637  0.00773833 -0.03546742][-0.14002582 -0.18578786 -0.08373584 -0.25964601  0.04090167][ 0.          0.          0.          0.          0.        ][ 0.          0.          0.          0.          0.        ]][[ 0.18564152  0.01531695  0.13752453  0.17188506  0.19555427][ 0.13703949  0.14272294  0.21313036  0.07417354  0.0477547 ][ 0.23021792  0.04455495  0.10204565  0.17159792  0.34148467][ 0.0386402   0.0387848   0.02134559  0.00110381  0.08414687][ 0.01386241 -0.02629686 -0.0733538  -0.03194245  0.13606553][ 0.01859433 -0.00585316 -0.04007138  0.03811594  0.21708331]]](2, 3, 5)LSTMStateTuple(c=array([[ 0.00909146, -0.03747076,  0.0008946 , -0.23459786,  0.29565899],[-0.18409266, -0.30463044, -0.28033809, -0.49032542,  0.12597639],[ 0.04494702, -0.01359631, -0.06706629,  0.06766361,  0.40794032]]), h=array([[ 0.00417564, -0.01985144,  0.00050634, -0.13238986,  0.14323784],[-0.14002582, -0.18578786, -0.08373584, -0.25964601,  0.04090167],[ 0.01859433, -0.00585316, -0.04007138,  0.03811594,  0.21708331]]))

实验二:cell类型为GRU,我们看看输出是什么样子,如下图所示,输入的形状为 [ 3, 6, 4 ],经过tf.nn.dynamic_rnn后outputs的形状为 [ 3, 6, 5 ],state形状为 [ 3, 5 ]。可以看到 state 与 对应的outputs的最后一行是相等的。

(3, 6, 5)
[[[-0.05190962 -0.13519617  0.02045928 -0.0821183   0.28337528][ 0.0201574   0.03779418 -0.05092804  0.02958051  0.12232347][ 0.14884441 -0.26075898  0.1821795  -0.03454954  0.18424161][-0.13854156 -0.26565378  0.09567164 -0.03960079  0.14000589][-0.2605973  -0.39901657  0.12495693 -0.19295695  0.52423598][-0.21596414 -0.63051687  0.20837501 -0.31775378  0.77519457]][[-0.1979659  -0.30253523  0.0248779  -0.17981144  0.41815343][ 0.34481129 -0.05256187  0.1643036   0.00739746  0.27384158][ 0.49703664  0.22241165  0.27344766  0.00093435  0.09854949][ 0.23312444  0.156997    0.25482553  0.0138156  -0.02302272][ 0.          0.          0.          0.          0.        ][ 0.          0.          0.          0.          0.        ]][[-0.06401732  0.08605342 -0.03936866 -0.02287695  0.16947652][-0.1775206  -0.2801672  -0.0387468  -0.20264583  0.58125297][ 0.39408762 -0.44066425  0.25826641 -0.18851604  0.36172166][ 0.0536013  -0.29902928  0.08891931 -0.03930039  0.0743423 ][ 0.02304702 -0.0612499   0.09113458 -0.05169013  0.29876455][-0.06711324  0.014125   -0.05856332 -0.05632359 -0.00390189]]]
(3, 5)
[[-0.21596414 -0.63051687  0.20837501 -0.31775378  0.77519457][ 0.23312444  0.156997    0.25482553  0.0138156  -0.02302272][-0.06711324  0.014125   -0.05856332 -0.05632359 -0.00390189]]


 

这篇关于tf.nn.dynamic_rnn的输出outputs和state含义的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/688054

相关文章

顺序表之创建,判满,插入,输出

文章目录 🍊自我介绍🍊创建一个空的顺序表,为结构体在堆区分配空间🍊插入数据🍊输出数据🍊判断顺序表是否满了,满了返回值1,否则返回0🍊main函数 你的点赞评论就是对博主最大的鼓励 当然喜欢的小伙伴可以:点赞+关注+评论+收藏(一键四连)哦~ 🍊自我介绍   Hello,大家好,我是小珑也要变强(也是小珑),我是易编程·终身成长社群的一名“创始团队·嘉宾”

AI(文生语音)-TTS 技术线路探索学习:从拼接式参数化方法到Tacotron端到端输出

AI(文生语音)-TTS 技术线路探索学习:从拼接式参数化方法到Tacotron端到端输出 在数字化时代,文本到语音(Text-to-Speech, TTS)技术已成为人机交互的关键桥梁,无论是为视障人士提供辅助阅读,还是为智能助手注入声音的灵魂,TTS 技术都扮演着至关重要的角色。从最初的拼接式方法到参数化技术,再到现今的深度学习解决方案,TTS 技术经历了一段长足的进步。这篇文章将带您穿越时

状态模式state

学习笔记,原文链接 https://refactoringguru.cn/design-patterns/state 在一个对象的内部状态变化时改变其行为, 使其看上去就像改变了自身所属的类一样。 在状态模式中,player.getState()获取的是player的当前状态,通常是一个实现了状态接口的对象。 onPlay()是状态模式中定义的一个方法,不同状态下(例如“正在播放”、“暂停

如何将一个文件里不包含某个字符的行输出到另一个文件?

第一种: grep -v 'string' filename > newfilenamegrep -v 'string' filename >> newfilename 第二种: sed -n '/string/!'p filename > newfilenamesed -n '/string/!'p filename >> newfilename

Detectorn2预训练模型复现:数据准备、训练命令、日志分析与输出目录

Detectorn2预训练模型复现:数据准备、训练命令、日志分析与输出目录 在深度学习项目中,目标检测是一项重要的任务。本文将详细介绍如何使用Detectron2进行目标检测模型的复现训练,涵盖训练数据准备、训练命令、训练日志分析、训练指标以及训练输出目录的各个文件及其作用。特别地,我们将演示在训练过程中出现中断后,如何使用 resume 功能继续训练,并将我们复现的模型与Model Zoo中的

第六章习题11.输出以下图形

🌏个人博客:尹蓝锐的博客 希望文章能够给到初学的你一些启发~ 如果觉得文章对你有帮助的话,点赞 + 关注+ 收藏支持一下笔者吧~ 1、题目要求: 输出以下图形

LibSVM学习(五)——分界线的输出

对于学习SVM人来说,要判断SVM效果,以图形的方式输出的分解线是最直观的。LibSVM自带了一个可视化的程序svm-toy,用来输出类之间的分界线。他是先把样本文件载入,然后进行训练,通过对每个像素点的坐标进行判断,看属于哪一类,就附上那类的颜色,从而使类与类之间形成分割线。我们这一节不讨论svm-toy怎么使用,因为这个是“傻瓜”式的,没什么好讨论的。这一节我们主要探讨怎么结合训练结果文件

下载/保存/读取 文件,并转成流输出

最近对文件的操作又熟悉了下;现在记载下来:学习在于 坚持!!!不以细小而不为。 实现的是:文件的下载、文件的保存到SD卡、文件的读取输出String 类型、最后是文件转换成流输出;一整套够用了; 重点: 1:   操作网络要记得开线程; 2:更新网络获取的数据 切记用Handler机制; 3:注意代码的可读性(这里面只是保存到SD卡,在项目中切记要对SD卡的有无做判断,然后再获取路径!)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

彻底解决win10系统Tomcat10控制台输出中文乱码

彻底解决Tomcat10控制台输出中文乱码 首先乱码问题的原因通俗的讲就是读的编码格式和写的解码格式不一致,比如最常见的两种中文编码UTF-8和GBK,UTF-8一个汉字占三个字节,GBK一个汉字占两个字节,所以当编码与解码格式不一致时,输出端当然无法识别这是啥,所以只能以乱码代替。 值得一提的是GBK不是国家标准编码,常用的国标有两,一个是GB2312,一个是GB18030 GB1