白话RNN系列(七)

2024-09-06 05:38
文章标签 系列 rnn 白话

本文主要是介绍白话RNN系列(七),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文,探讨下LSTM的一些高级应用,比如双向LSTM。

前面的探讨过程中, 我们使用到的RNN或者LSTM都是单向的,即按照时间顺序排列的一维序列;而在实际应用中,双向的RNN由于考虑到更充足的上下文,往往能起到更好的效果:

Bi-RNN又叫双向RNN,是采用了两个方向的RNN网络。
RNN网络擅长的是对于连续数据的处理,既然是连续的数据规律,我们不仅可以学习它的正向规律,还可以学习它的反向规律。这样将正向和反向结合的网络,会比单向的循环网络有更高的拟合度。例如,预测一个语句中缺失的词语,则需要根据上下文来进行预测。
双向RNN的处理过程与单向的RNN非常类似,就是在正向传播的基础上再进行一次反向传播,而且这两个都连接着一个输出层。这个结构提供给输出层输入序列中,每一个点完整的过去和未来的上下文信息。下图所示为一个沿着时间展开的双向循环神经网络:

在按照时间序列正向运算完之后,网络又从时间的最后一项反向地运算一遍,假如上图中我们定义T=3,即把t3时刻的输入与默认值0一起生成反向的out3,把反向out3当成t2时刻的输入与原来的t2时刻输入一起生成反向out2;依此类推,直到第一个时序数据。

双向循环网络的输出是2个,正向一个,反向一个。最终会把输出结果通过concat并联在一起,然后交给后面的层来处理。例如,数据输入[batch,nhidden],输出就会变成[batch,nhidden×2]。

结构看起来很简单,我们还是找个具体的例子,来对双向RNN进行一个透彻的分析;这里,我们使用单层动态双向RNN网络对MNIST数据集进行分类:

n_input = 28  # MNIST data 输入 (img shape: 28*28)
n_steps = 28  # timesteps
n_hidden = 128  # hidden layer num of features
n_classes = 10  # MNIST 列别 (0-9 ,一共10类)

初始化的一些变量,与原先的LSTM并无很大区别:

区别主要在如下地方:

x1 = tf.unstack(x, n_steps, 1)
# 我们定义了前向的LSTMCell, 其隐藏层包含了128个节点
lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
# 同时,定义了反向的LSTMCell,其隐藏层同样包含了128个节点
lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

在定义LSTMCell的时候,我们需要定义两遍LSTMCell,其隐藏层都拥有128个节点:

outputs, output_states = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, x,dtype=tf.float32)

训练方式稍有不同,可以看到,我们这里使用了bidirectional_dynamic_rnn的训练方式;同时输入的x是128 * 28 * 28 的数据,并未按照时间序列进行切分:

        out = sess.run(outputs_first, feed_dict={x: batch_x, y: batch_y})print('out-0:'+str(out[0].shape))print('out-1:'+str(out[1].shape))

我们在训练过程中,输出outputs两个元素的形状,发现:

out-0:(100, 28, 128)
out-1:(100, 28, 128)

前面定义的100位batch_size即每个批次数据量的大小;中间的28即循环次数;128 即隐藏层神经元的数目:

outputs = tf.concat(outputs_first, 2)
outputs = tf.transpose(outputs, [1, 0, 2])

通常,我们会把两次输出进行拼接,生成最后的输入。

余下,都和单向的LSTM完全一致了:

总结下,输出的outputs是前向和后向分开的。这种方法最原始也最灵活,但要注意,一定要把两个输出结果进行融合,我们可以采用拼接的形式,自然也可以采用其他的向量相加的方式等来实现自己的目的。

下面,我们再使用静态双向RNN把上述操作实现一遍,看看其与静态双向RNN有什么本质的区别:

lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
# 反向cell
lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x1,dtype=tf.float32)
print(outputs[0].shape,len(outputs))
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None)

很明显两个区别:

1.输出的形式,静态需要对输入数据进行unstack操作,分割成按照时间顺序排列的数据(而动态执行不需要此种方式,更加简洁)

2.其输出outputs[-1]直接是拼接好的数据,我们在执行过程中可以看到:

(?, 256) 28

毫无疑问,这种方式没有动态执行方式灵活。

当然,我们在实现过程中,还可以采用多层RNN的方式来实现自己的目的,在此处不赘述了。

这篇关于白话RNN系列(七)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1141158

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

flume系列之:查看flume系统日志、查看统计flume日志类型、查看flume日志

遍历指定目录下多个文件查找指定内容 服务器系统日志会记录flume相关日志 cat /var/log/messages |grep -i oom 查找系统日志中关于flume的指定日志 import osdef search_string_in_files(directory, search_string):count = 0

GPT系列之:GPT-1,GPT-2,GPT-3详细解读

一、GPT1 论文:Improving Language Understanding by Generative Pre-Training 链接:https://cdn.openai.com/research-covers/languageunsupervised/language_understanding_paper.pdf 启发点:生成loss和微调loss同时作用,让下游任务来适应预训

Java基础回顾系列-第七天-高级编程之IO

Java基础回顾系列-第七天-高级编程之IO 文件操作字节流与字符流OutputStream字节输出流FileOutputStream InputStream字节输入流FileInputStream Writer字符输出流FileWriter Reader字符输入流字节流与字符流的区别转换流InputStreamReaderOutputStreamWriter 文件复制 字符编码内存操作流(

Java基础回顾系列-第五天-高级编程之API类库

Java基础回顾系列-第五天-高级编程之API类库 Java基础类库StringBufferStringBuilderStringCharSequence接口AutoCloseable接口RuntimeSystemCleaner对象克隆 数字操作类Math数学计算类Random随机数生成类BigInteger/BigDecimal大数字操作类 日期操作类DateSimpleDateForma

Java基础回顾系列-第三天-Lambda表达式

Java基础回顾系列-第三天-Lambda表达式 Lambda表达式方法引用引用静态方法引用实例化对象的方法引用特定类型的方法引用构造方法 内建函数式接口Function基础接口DoubleToIntFunction 类型转换接口Consumer消费型函数式接口Supplier供给型函数式接口Predicate断言型函数式接口 Stream API 该篇博文需重点了解:内建函数式

Java基础回顾系列-第二天-面向对象编程

面向对象编程 Java类核心开发结构面向对象封装继承多态 抽象类abstract接口interface抽象类与接口的区别深入分析类与对象内存分析 继承extends重写(Override)与重载(Overload)重写(Override)重载(Overload)重写与重载之间的区别总结 this关键字static关键字static变量static方法static代码块 代码块String类特

Java基础回顾系列-第六天-Java集合

Java基础回顾系列-第六天-Java集合 集合概述数组的弊端集合框架的优点Java集合关系图集合框架体系图java.util.Collection接口 List集合java.util.List接口java.util.ArrayListjava.util.LinkedListjava.util.Vector Set集合java.util.Set接口java.util.HashSetjava