白话RNN系列(四)

2024-09-06 05:38
文章标签 rnn 白话 系列

本文主要是介绍白话RNN系列(四),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文,谈谈RNN的一个变种,也是目前使用比较广泛的神经网络LSTM,我们首先描述下LSTM的基本结构,然后给出一个具体的使用LSTM的例子,帮助大家尽快掌握LSTM的原理和基本使用方法;

这可能是一张大家熟悉地不能再熟悉的图片了。

我们可以将其与RNN的基本结构进行对比:

è¿éåå¾çæè¿°

 我们可以看到区别:RNN中,每个循环体会产生一份输出,即隐藏状态;最终输出由此隐藏状态产出,同时,隐藏状态会保留,并与下一次输入结合起来,作为下一时刻的输入。

而在LSTM中,每个循环体会产生两份输出:一份是隐藏层状态,另一份则是当前细胞的状态,称之为Cell State,而这个Cell State就是LSTM保持长期记忆的关键。

好,我们接下来深入分析下LSTM的结构。

关键字:细胞状态,遗忘门,输入门,输出门。

以下的分析来自于李金洪老师的书,《深入学习之TensorFlow》

1.遗忘门

上图中的A ,是我们的循环体,其实质由三部分组成,即遗忘门,输入门,输出门三块组成,我们把这幅图拆分出来,仔细看下三个门的结构及其用途:

遗忘门即图中黑色实线部分,右侧即相关的计算公式。

遗忘门的作用在于:决定什么时候需要把以前的状态忘记,并且其做法是通过训练,依赖当前输入和上一层隐藏状态来决定模型从细胞状态中丢弃什么信息。

而从图中来看,遗忘门的实质是,通过当前输入x(t)和上一层的隐藏状态h(t-1),通过一个简单的全连接,生成一个f(t),这是一个0~1之间的数值,0表示全部忘记,1表示全部保留。

例如一个语言模型的例子,假设细胞状态会包含当前主语的性别,于是根据这个状态便可以选择正确的代词。当我们看到新的主语时,应该把新的主语在记忆中更新。该门的功能就是先去记忆中找到以前那个旧的主语(并没有真正忘掉操作,只是找到而已)。

2.输入门

中间这块,是输入门,顾名思义,其决定了我们的输入,当然,只是部分输入(其还需要与细胞状态结合,得到我们真正的输入)。

右边的公式也很容易看明白,初步的i(t)通过h(t-1)和x(t)获得;同时,这里会根据h(t-1)和x(t)生成一个新的细胞状态;如果前面的遗忘门决定把过往的细胞状态全部遗忘的话,接下来传递的就是此时新建的细胞状态了。

这里,我们会发现,新的细胞状态的更新,是遗忘门与输入门一起进行的:忘记门找到了需要忘掉的信息ft后,再将它与旧状态相乘,丢弃掉确定需要丢弃的信息。再将结果加上it×Ct使细胞状态获得新的信息,这样就完成了细胞状态的更新

上图展示的即是通过遗忘门和输入门对细胞状态进行更新的步骤。

3.输出门

可以看出,模型通过一个Sigmoid层来确定哪部分的信息将输出,接着把细胞状态通过Tanh进行处理(得到一个在-1~1之间的值)并将它和Sigmoid门的输出相乘,得出最终想要输出的那部分,例如在语言模型中,假设已经输入了一个代词,便会计算出需要输出一个与动词相关的信息。

一般来说,出现激活函数的地方,都会自带一个全连接的神经网络,那么,LSTM每个细胞,就自带有五个全连接网络,这五个全连接组成LSTM的Cell的基本结构。

对于LSTM的前向传播和后向传播,我们此处先不分析;我们先给出个例子分析,到底如何正确地使用LSTM,但是在分析之前,我们从输入向量的维度,到输出向量的维度做一个全盘分析,了解清楚每个向量的维度:

def __init__(self, num_units,use_peepholes=False, cell_clip=None,initializer=None, num_proj=None, proj_clip=None,num_unit_shards=None, num_proj_shards=None,forget_bias=1.0, state_is_tuple=True,activation=None, reuse=None, name=None, dtype=None, **kwargs):

上面的代码rnn_cell_impl中截取而来,我们在使用过程中,会对LSTMCell进行初始化:

n_hidden = 128
lstm_cell = tf.nn.rnn_cell.LSTMCell(n_hidden, forget_bias=1.0,name='basic_lstm_cell')

ok,我这里设定n_hidden=128,可以看出,因为_init_方法中只有一个参数未定义默认值,即num_units,所以我们填充的n_hidden=128即num_units的值,那么问题来了,我们设置的n_units代表了什么?

num_units: int, The number of units in the LSTM cell.

官方注释中是这么说的,LSTMCell中的单元数目,LSTM中有什么单元?只有一个又一个的全连接网络,哦了,这个其实就是隐藏层神经元的数目,虽然LSTMCell中有多个全连接神经网络,但实际上这些隐藏层神经网络的隐藏层节点数是统一设定的。

这个很好理解,为了保持向量最终的一致性。

这里的分析,跟RNN的分析有相似的地方,我们套用一下:

1:遗忘门(Forget Gate)

该图片的与上面图片完全一致:

我们假定输入的x(t)为一个one-hot向量,维度为100 * 1;而h(t-1)的维度则是128 * 1 ;我们把两个列向量拼接在一起,其维度为228 * 1, 而W(f)为了完成左乘,其维度为128 * 228 ,得到128 * 1的隐藏层输出;列向量经过激活函数,每个元素都相当于做了归一化,但依旧保持了128 * 1的维度。

我们会发现,f(t)实际上也是个向量,该向量的每个维度数值经过sigmoid函数,取值都在[0,1]范围内,正好跟前面对应上,如果sigmoid化后为1,则代表全部保留;如果sigmoid化后为0,则全部抛弃

2:输入门

毫无疑问,这里得到的i(t)维度仍然是128 * 1的列向量,但是经过sigmoid化后,其每个维度的值都局限在[0,1]范围内;而我们初始化得到的C(t)同样是128 * 1的列向量;以此类推,上面一直传递的细胞状态,同样是128 * 1的维度。

前文说过,输入门与遗忘门相互配合,得到最终的输入门的更新;我们注意这里的* 号,实际表达的是对应元素的乘法,即我们128 * 1 的列向量乘以128 * 1列向量,得到的依旧是128 * 1维的列向量;这样才能解释我们为什么输出的细胞状态依旧是128 * 1的列向量。

并且,可以看出,我们的细胞状态,由两部分组成:部分旧的细胞状态,再加上我们此次细胞初始化的状态。

3.输出门

è¿éåå¾çæè¿°

ok,终于到了输出门的状态了;

可以看到,o(t)同样是经过一个全连接的神经网络产生的,其隐藏层拥有128个神经元,得到的输出o(t)是128 * 1维的列向量;同样,我们地C(t)也是128 * 1的列向量;因此,得到最终隐藏层的输出也是128 * 1维度的列向量。

稍微总结下:

1.遗忘门与输入门配合,对细胞状态进行更新

2.隐藏状态的输出,需要输入门与细胞状态的共同作用。

这里对各个牵涉到的向量进行具体分析,是为了方便我们对LSTM神经网络的使用,下一篇文章会详细介绍一个可用的LSTM网络,帮助大家能够快速、正确使用LSTM来完成自己的任务。

这篇关于白话RNN系列(四)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1141155

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

flume系列之:查看flume系统日志、查看统计flume日志类型、查看flume日志

遍历指定目录下多个文件查找指定内容 服务器系统日志会记录flume相关日志 cat /var/log/messages |grep -i oom 查找系统日志中关于flume的指定日志 import osdef search_string_in_files(directory, search_string):count = 0

GPT系列之:GPT-1,GPT-2,GPT-3详细解读

一、GPT1 论文:Improving Language Understanding by Generative Pre-Training 链接:https://cdn.openai.com/research-covers/languageunsupervised/language_understanding_paper.pdf 启发点:生成loss和微调loss同时作用,让下游任务来适应预训

Java基础回顾系列-第七天-高级编程之IO

Java基础回顾系列-第七天-高级编程之IO 文件操作字节流与字符流OutputStream字节输出流FileOutputStream InputStream字节输入流FileInputStream Writer字符输出流FileWriter Reader字符输入流字节流与字符流的区别转换流InputStreamReaderOutputStreamWriter 文件复制 字符编码内存操作流(

Java基础回顾系列-第五天-高级编程之API类库

Java基础回顾系列-第五天-高级编程之API类库 Java基础类库StringBufferStringBuilderStringCharSequence接口AutoCloseable接口RuntimeSystemCleaner对象克隆 数字操作类Math数学计算类Random随机数生成类BigInteger/BigDecimal大数字操作类 日期操作类DateSimpleDateForma

Java基础回顾系列-第三天-Lambda表达式

Java基础回顾系列-第三天-Lambda表达式 Lambda表达式方法引用引用静态方法引用实例化对象的方法引用特定类型的方法引用构造方法 内建函数式接口Function基础接口DoubleToIntFunction 类型转换接口Consumer消费型函数式接口Supplier供给型函数式接口Predicate断言型函数式接口 Stream API 该篇博文需重点了解:内建函数式

Java基础回顾系列-第二天-面向对象编程

面向对象编程 Java类核心开发结构面向对象封装继承多态 抽象类abstract接口interface抽象类与接口的区别深入分析类与对象内存分析 继承extends重写(Override)与重载(Overload)重写(Override)重载(Overload)重写与重载之间的区别总结 this关键字static关键字static变量static方法static代码块 代码块String类特

Java基础回顾系列-第六天-Java集合

Java基础回顾系列-第六天-Java集合 集合概述数组的弊端集合框架的优点Java集合关系图集合框架体系图java.util.Collection接口 List集合java.util.List接口java.util.ArrayListjava.util.LinkedListjava.util.Vector Set集合java.util.Set接口java.util.HashSetjava