白话RNN系列(六)

2024-09-06 05:38
文章标签 rnn 白话 系列

本文主要是介绍白话RNN系列(六),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

上文给出了一个LSTM使用的具体例子,但其中依旧存在一些东西说的不是很清楚明白,接下来,我们会针对LSTM使用中更加细致的一些东西,做一些介绍。

本人目前使用的基本都是TensorFlow进行开发。

lstm_cell = tf.nn.rnn_cell.LSTMCell(n_hidden, forget_bias=1.0, name='basic_lstm_cell')
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x1, dtype=tf.float32)

假如我们以此种方式运行一个LSTM,毫无疑问,需要对outpus,和states 做一个深入的了解,因为LSTMCell只是神经网络的一层,通常我们后面都会有其他的全连接神经网络与之相连,完成分类或者其他的任务。

问题1:outputs和states的深入学习和理解

其次,我们上述代码使用了静态运行的方式,有没有其他的运行方式呢?

问题2:循环神经网络的其他运行方式

第三个问题,有多少种LSTMCell我们可以拿来用?还记得有一种叫做GRU的东西么?其与LSTMCell非常相像,但是结构比LSTMCell要简单,使用方式基本一致。

首先,我们从static_rnn方法的基本参数:

def static_rnn(cell, inputs, initial_state=None,dtype=None,sequence_ 
length=None, scope=None):

cell :即我们上面生成好的Cell类对象

inputs:输入数据,一定是list或者是二维张量,类比我们前面的x1,其就是28个元素的list

initital_state : 即初始状态,系统会初始化为全零的隐藏层状态,通常我们不用自行初始化

返回值有两个,一个是结果即outputs,一个是cell状态;我们只关注结果,而实际上结果也是一个list,输入是多少个时序,list里就会输出多少个元素。

类比于白话RNN系列(五)中,其outputs实际上是一个length=28的list,每个元素均为128 * 128 的张量;前面的128实际上代表的批次大小,而后面的128 则是代表我们输出的数据,这个128维的向量会与后续的全连接神经网络共同产出分类结果数据。

我们之所以取outputs[-1],原因在于其最是最后一个循环的输出,才是我们真正需要的数据。

而对于states,实际上是一个LSTMStateTuple类型的数据,本身是个元组;元组里面包含两个元素:c和h,c表示的就是最后时刻cell的内部状态值,h表示的就是最后时刻隐层状态值。

隐约有种感觉,outputs[-1]和states的第二个元素应该是相同的东西,都代表最后时刻的隐层状态值;事实证明,的确如此:

state:[[ 0.01089936  0.02731343  0.05962883 ...  0.11460225  0.093943930.08642814][ 0.05554791 -0.06378083  0.08341898 ...  0.09763484  0.014068910.00135677][-0.00513092 -0.0474266   0.07704069 ...  0.10713644  0.009927520.03562667]...[ 0.05851342 -0.0413289   0.08804315 ...  0.10628442  0.012927210.01195676][ 0.0014212  -0.00056065  0.04524097 ...  0.01346213 -0.00850293-0.00390543][ 0.01522607  0.03220217  0.06999128 ...  0.06230348  0.043055340.05264532]]
out-1:[[ 0.01089936  0.02731343  0.05962883 ...  0.11460225  0.093943930.08642814][ 0.05554791 -0.06378083  0.08341898 ...  0.09763484  0.014068910.00135677][-0.00513092 -0.0474266   0.07704069 ...  0.10713644  0.009927520.03562667]...[ 0.05851342 -0.0413289   0.08804315 ...  0.10628442  0.012927210.01195676][ 0.0014212  -0.00056065  0.04524097 ...  0.01346213 -0.00850293-0.00390543][ 0.01522607  0.03220217  0.06999128 ...  0.06230348  0.043055340.05264532]]

 这个结果可以通过在原有代码中嵌入状态输出和outputs[-1]输出得到:如下

        out = sess.run(outputs, feed_dict={x: batch_x, y: batch_y})sta = sess.run(states, feed_dict={x: batch_x, y: batch_y})print('state:' + str(sta[1]))print('out-1:' + str(out[-1]))

因此,稍微总结下:outputs实际上是一个时序list,循环体循环多少次,则有多少个元素;而每个元素为一个张量,形如100* 128, 其中100代表批次大小,而128代表隐藏层的状态。

同样,states代表了隐藏层状态,是一个元组;并且,states[1]和outputs[-1]实质是一样的。

OK,这里分析的outputs和states是与static_run方法紧密相连的,我们看下其他的运行方式,会有什么不同的效果:

除了静态运行外,循环神经网络还存在动态运行的方式:

outputs, _ = tf.nn.dynamic_rnn(gru, x, dtype=tf.float32)
outputs = tf.transpose(outputs, [1, 0, 2])

上图即动态运行方式,我们看下dynamic_rnn的参数:

def dynamic_rnn (cell, inputs, sequence_length=None, initial_state=None,dtype=None, parallel_iterations=None, swap_memory=False,time_major=False, scope=None):

参数大致与static_run都是一致的,但有几个地方不同:

1.inputs:输入数据,是一个张量,一般是三维张量,[batch_size,max_time,...]。其中batch_size表示一次的批次数量,max_time表示时间序列总数,后面是具体数据。

类比,我们这里不需要对x进行unstack操作,直接输入128 * 784 (前面128是批次大小,后面784 是序列总长度)即可。

对于其输出,我们关注outputs即可,当time_major为默认值False时,input的shape为[batch_size,max_time,...]。如果是True,shape为[max_time,batch_size,...]; 因此,我们可以显式定义time_major=True,来保证取出的outputs[-1]是我们需要使用到的隐藏层状态。

在LSTM的具体使用中,我们一方面需要关注自己输入的张量维度,同时也要谨慎关注我们输出的张量维度,二者结合在一起,能够起到更好的效果。

这篇关于白话RNN系列(六)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1141157

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

flume系列之:查看flume系统日志、查看统计flume日志类型、查看flume日志

遍历指定目录下多个文件查找指定内容 服务器系统日志会记录flume相关日志 cat /var/log/messages |grep -i oom 查找系统日志中关于flume的指定日志 import osdef search_string_in_files(directory, search_string):count = 0

GPT系列之:GPT-1,GPT-2,GPT-3详细解读

一、GPT1 论文:Improving Language Understanding by Generative Pre-Training 链接:https://cdn.openai.com/research-covers/languageunsupervised/language_understanding_paper.pdf 启发点:生成loss和微调loss同时作用,让下游任务来适应预训

Java基础回顾系列-第七天-高级编程之IO

Java基础回顾系列-第七天-高级编程之IO 文件操作字节流与字符流OutputStream字节输出流FileOutputStream InputStream字节输入流FileInputStream Writer字符输出流FileWriter Reader字符输入流字节流与字符流的区别转换流InputStreamReaderOutputStreamWriter 文件复制 字符编码内存操作流(

Java基础回顾系列-第五天-高级编程之API类库

Java基础回顾系列-第五天-高级编程之API类库 Java基础类库StringBufferStringBuilderStringCharSequence接口AutoCloseable接口RuntimeSystemCleaner对象克隆 数字操作类Math数学计算类Random随机数生成类BigInteger/BigDecimal大数字操作类 日期操作类DateSimpleDateForma

Java基础回顾系列-第三天-Lambda表达式

Java基础回顾系列-第三天-Lambda表达式 Lambda表达式方法引用引用静态方法引用实例化对象的方法引用特定类型的方法引用构造方法 内建函数式接口Function基础接口DoubleToIntFunction 类型转换接口Consumer消费型函数式接口Supplier供给型函数式接口Predicate断言型函数式接口 Stream API 该篇博文需重点了解:内建函数式

Java基础回顾系列-第二天-面向对象编程

面向对象编程 Java类核心开发结构面向对象封装继承多态 抽象类abstract接口interface抽象类与接口的区别深入分析类与对象内存分析 继承extends重写(Override)与重载(Overload)重写(Override)重载(Overload)重写与重载之间的区别总结 this关键字static关键字static变量static方法static代码块 代码块String类特

Java基础回顾系列-第六天-Java集合

Java基础回顾系列-第六天-Java集合 集合概述数组的弊端集合框架的优点Java集合关系图集合框架体系图java.util.Collection接口 List集合java.util.List接口java.util.ArrayListjava.util.LinkedListjava.util.Vector Set集合java.util.Set接口java.util.HashSetjava