#####好好好#####深度学习笔记——Attention Model(注意力模型)学习总结

本文主要是介绍#####好好好#####深度学习笔记——Attention Model(注意力模型)学习总结,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

深度学习里的Attention model其实模拟的是人脑的注意力模型,举个例子来说,当我们观赏一幅画时,虽然我们可以看到整幅画的全貌,但是在我们深入仔细地观察时,其实眼睛聚焦的就只有很小的一块,这个时候人的大脑主要关注在这一小块图案上,也就是说这个时候人脑对整幅图的关注并不是均衡的,是有一定的权重区分的。这就是深度学习里的Attention Model的核心思想。

AM刚开始也确实是应用在图像领域里的,AM在图像处理领域取得了非常好的效果!于是,就有人开始研究怎么将AM模型引入到NLP领域。最有名的当属“Neural machine translation by jointly learning to align and translate”这篇论文了,这篇论文最早提出了Soft Attention Model,并将其应用到了机器翻译领域。后续NLP领域使用AM模型的文章一般都会引用这篇文章(目前引用量已经上千了!!!)

如下图所示,机器翻译主要使用的是Encoder-Decoder模型,在Encoder-Decoder模型的基础上引入了AM,取得了不错的效果:


Soft Attention Model:

这里其实是上面图的拆解,我们前面说过,“Neural machine translation by jointly learning to align and translate”这篇论文提出了soft Attention Model,并将其应用到了机器翻译上面。其实,所谓Soft,意思是在求注意力分配概率分布的时候,对于输入句子X中任意一个单词都给出个概率,是个概率分布。

即上图中的ci是对Encoder中每一个单词都要计算一个注意力概率分布,然后加权得到的。如下图所示:


其实有Soft AM,对应也有一个Hard AM。既然Soft是给每个单词都赋予一个单词对齐概率,那么如果不这样做,直接从输入句子里面找到某个特定的单词,然后把目标句子单词和这个单词对齐,而其它输入句子中的单词硬性地认为对齐概率为0,这就是Hard Attention Model的思想。Hard AM在图像里证明有用,但是在文本里面用处不大,因为这种单词一一对齐明显要求太高,如果对不齐对后续处理负面影响很大。

但是,斯坦福大学的一篇paper“Effective Approaches to Attention-based Neural Machine Translation”提出了一个混合Soft AM 和Hard AM的模型,论文中,他们提出了两种模型:Global Attention Model和Local Attention Model,Global Attention Model其实就是Soft Attention Model,Local Attention Model本质上是Soft AM和 Hard AM的一个混合。一般首先预估一个对齐位置Pt,然后在Pt左右大小为D的窗口范围来取类似于Soft AM的概率分布。

Global Attention Model和Local Attention Model

Global AM其实就是soft AM,Decoder的过程中,每一个时间步的Context vector需要计算Encoder中每一个单词的注意力权重,然后加权得到。



Local AM则是首先找到一个对其位置,然后在对其位置左右一个窗口内来计算注意力权重,最终加权得到Context vector。这其实是Soft AM和Hard AM的一个混合折中。


静态AM

其实还有一种AM叫做静态AM。所谓静态AM,其实是指对于一个文档或者句子,计算每个词的注意力概率分布,然后加权得到一个向量来代表这个文档或者句子的向量表示。跟soft AM的区别是,soft AM在Decoder的过程中每一次都需要重新对所有词计算一遍注意力概率分布,然后加权得到context vector,但是静态AM只计算一次得到句子的向量表示即可。(这其实是针对于不同的任务而做出的改变)



强制前向AM

Soft AM在逐步生成目标句子单词的时候,是由前向后逐步生成的,但是每个单词在求输入句子单词对齐模型时,并没有什么特殊要求。强制前向AM则增加了约束条件:要求在生成目标句子单词时,如果某个输入句子单词已经和输出单词对齐了,那么后面基本不太考虑再用它了,因为输入和输出都是逐步往前走的,所以看上去类似于强制对齐规则在往前走。


看了这么多AM模型以及变种,那么我们来看一看AM模型具体怎么实现,涉及的公式都是怎样的。

我们知道,注意力机制是在序列到序列模型中用于注意编码器状态的最常用方法,它同时还可用于回顾序列模型的过去状态。使用注意力机制,系统能基于隐藏状态 s_1,...,s_m 而获得环境向量(context vector)c_i,这些环境向量可以和当前的隐藏状态 h_i 一起实现预测。环境向量 c_i 可以由前面状态的加权平均数得出,其中状态所加的权就是注意力权重 a_i:


注意力函数 f_att(h_i,s_j) 计算的是目前的隐藏状态 h_i 和前面的隐藏状态 s_j 之间的非归一化分配值。

而实际上,注意力函数也有很多种变体。接下来我们将讨论四种注意力变体:加性注意力(additive attention)、乘法(点积)注意力(multiplicative attention)、自注意力(self-attention)和关键值注意力(key-value attention)。

加性注意力(additive attention)

加性注意力是最经典的注意力机制 (Bahdanau et al., 2015) [15],它使用了有一个隐藏层的前馈网络(全连接)来计算注意力的分配:


也就是:


乘法(点积)注意力(multiplicative attention)

乘法注意力(Multiplicative attention)(Luong et al., 2015) [16] 通过计算以下函数而简化了注意力操作:


加性注意力和乘法注意力在复杂度上是相似的,但是乘法注意力在实践中往往要更快速、具有更高效的存储,因为它可以使用矩阵操作更高效地实现。两个变体在低维度 d_h 解码器状态中性能相似,但加性注意力机制在更高的维度上性能更优。


自注意力(self-attention)

注意力机制不仅能用来处理编码器或前面的隐藏层,它同样还能用来获得其他特征的分布,例如阅读理解任务中作为文本的词嵌入 (Kadlec et al., 2017) [37]。然而,注意力机制并不直接适用于分类任务,因为这些任务并不需要情感分析(sentiment analysis)等额外的信息。在这些模型中,通常我们使用 LSTM 的最终隐藏状态或像最大池化和平均池化那样的聚合函数来表征句子。

自注意力机制(Self-attention)通常也不会使用其他额外的信息,但是它能使用自注意力关注本身进而从句子中抽取相关信息 (Lin et al., 2017) [18]。自注意力又称作内部注意力,它在很多任务上都有十分出色的表现,比如阅读理解 (Cheng et al., 2016) [38]、文本继承 (textual entailment/Parikh et al., 2016) [39]、自动文本摘要 (Paulus et al., 2017) [40]。



关键值注意力(key-value attention)

关键值注意力 (Daniluk et al., 2017) [19] 是最近出现的注意力变体机制,它将形式和函数分开,从而为注意力计算保持分离的向量。它同样在多种文本建模任务 (Liu & Lapata, 2017) [41] 中发挥了很大的作用。具体来说,关键值注意力将每一个隐藏向量 h_i 分离为一个键值 k_i 和一个向量 v_i:[k_i;v_i]=h_i。键值使用加性注意力来计算注意力分布 a_i:


其中 L 为注意力窗体的长度,I 为所有单元为 1 的向量。然后使用注意力分布值可以求得环境表征 c_i:


其中环境向量 c_i 将联合现阶段的状态值 v_i 进行预测。


最后的最后,再加一个自己用keras实现的简单的静态AM(自注意力)层的代码吧:

[python]  view plain copy
  1. from keras import backend as K  
  2. from keras.engine.topology import Layer  
  3. from keras import initializers, regularizers, constraints  
  4.   
  5. class Attention_layer(Layer):  
  6.     """ 
  7.         Attention operation, with a context/query vector, for temporal data. 
  8.         Supports Masking. 
  9.         Follows the work of Yang et al. [https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf] 
  10.         "Hierarchical Attention Networks for Document Classification" 
  11.         by using a context vector to assist the attention 
  12.         # Input shape 
  13.             3D tensor with shape: `(samples, steps, features)`. 
  14.         # Output shape 
  15.             2D tensor with shape: `(samples, features)`. 
  16.         :param kwargs: 
  17.         Just put it on top of an RNN Layer (GRU/LSTM/SimpleRNN) with return_sequences=True. 
  18.         The dimensions are inferred based on the output shape of the RNN. 
  19.         Example: 
  20.             model.add(LSTM(64, return_sequences=True)) 
  21.             model.add(AttentionWithContext()) 
  22.         """  
  23.   
  24.     def __init__(self,  
  25.                  W_regularizer=None, b_regularizer=None,  
  26.                  W_constraint=None, b_constraint=None,  
  27.                  bias=True, **kwargs):  
  28.   
  29.         self.supports_masking = True  
  30.         self.init = initializers.get('glorot_uniform')  
  31.   
  32.         self.W_regularizer = regularizers.get(W_regularizer)  
  33.         self.b_regularizer = regularizers.get(b_regularizer)  
  34.   
  35.         self.W_constraint = constraints.get(W_constraint)  
  36.         self.b_constraint = constraints.get(b_constraint)  
  37.   
  38.         self.bias = bias  
  39.         super(Attention_layer, self).__init__(**kwargs)  
  40.   
  41.     def build(self, input_shape):  
  42.         assert len(input_shape) == 3  
  43.   
  44.         self.W = self.add_weight((input_shape[-1], input_shape[-1],),  
  45.                                  initializer=self.init,  
  46.                                  name='{}_W'.format(self.name),  
  47.                                  regularizer=self.W_regularizer,  
  48.                                  constraint=self.W_constraint)  
  49.         if self.bias:  
  50.             self.b = self.add_weight((input_shape[-1],),  
  51.                                      initializer='zero',  
  52.                                      name='{}_b'.format(self.name),  
  53.                                      regularizer=self.b_regularizer,  
  54.                                      constraint=self.b_constraint)  
  55.   
  56.         super(Attention_layer, self).build(input_shape)  
  57.   
  58.     def compute_mask(self, input, input_mask=None):  
  59.         # do not pass the mask to the next layers  
  60.         return None  
  61.   
  62.     def call(self, x, mask=None):  
  63.         uit = K.dot(x, self.W)  
  64.   
  65.         if self.bias:  
  66.             uit += self.b  
  67.   
  68.         uit = K.tanh(uit)  
  69.   
  70.         a = K.exp(uit)  
  71.   
  72.         # apply mask after the exp. will be re-normalized next  
  73.         if mask is not None:  
  74.             # Cast the mask to floatX to avoid float64 upcasting in theano  
  75.             a *= K.cast(mask, K.floatx())  
  76.   
  77.         # in some cases especially in the early stages of training the sum may be almost zero  
  78.         # and this results in NaN's. A workaround is to add a very small positive number to the sum.  
  79.         # a /= K.cast(K.sum(a, axis=1, keepdims=True), K.floatx())  
  80.         a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())  
  81.         print a  
  82.         # a = K.expand_dims(a)  
  83.         print x  
  84.         weighted_input = x * a  
  85.         print weighted_input  
  86.         return K.sum(weighted_input, axis=1)  
  87.   
  88.     def compute_output_shape(self, input_shape):  
  89.         return (input_shape[0], input_shape[-1])  

这篇关于#####好好好#####深度学习笔记——Attention Model(注意力模型)学习总结的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/967656

相关文章

C++ 右值引用(rvalue references)与移动语义(move semantics)深度解析

《C++右值引用(rvaluereferences)与移动语义(movesemantics)深度解析》文章主要介绍了C++右值引用和移动语义的设计动机、基本概念、实现方式以及在实际编程中的应用,... 目录一、右值引用(rvalue references)与移动语义(move semantics)设计动机1

C# List.Sort四种重载总结

《C#List.Sort四种重载总结》本文详细分析了C#中List.Sort()方法的四种重载形式及其实现原理,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友... 目录1. Sort方法的四种重载2. 具体使用- List.Sort();- IComparable

SpringBoot项目整合Netty启动失败的常见错误总结

《SpringBoot项目整合Netty启动失败的常见错误总结》本文总结了SpringBoot集成Netty时常见的8类问题及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参... 目录一、端口冲突问题1. Tomcat与Netty端口冲突二、主线程被阻塞问题1. Netty启动阻

SpringBoot整合Kafka启动失败的常见错误问题总结(推荐)

《SpringBoot整合Kafka启动失败的常见错误问题总结(推荐)》本文总结了SpringBoot项目整合Kafka启动失败的常见错误,包括Kafka服务器连接问题、序列化配置错误、依赖配置问题、... 目录一、Kafka服务器连接问题1. Kafka服务器无法连接2. 开发环境与生产环境网络不通二、序

SQL 注入攻击(SQL Injection)原理、利用方式与防御策略深度解析

《SQL注入攻击(SQLInjection)原理、利用方式与防御策略深度解析》本文将从SQL注入的基本原理、攻击方式、常见利用手法,到企业级防御方案进行全面讲解,以帮助开发者和安全人员更系统地理解... 目录一、前言二、SQL 注入攻击的基本概念三、SQL 注入常见类型分析1. 基于错误回显的注入(Erro

python3中正则表达式处理函数用法总结

《python3中正则表达式处理函数用法总结》Python中的正则表达式是一个强大的文本处理工具,用于匹配、查找、替换等操作,在Python中正则表达式的操作主要通过内置的re模块来实现,这篇文章主要... 目录前言re.match函数re.search方法re.match 与 re.search的区别检索

Java领域模型示例详解

《Java领域模型示例详解》本文介绍了Java领域模型(POJO/Entity/VO/DTO/BO)的定义、用途和区别,强调了它们在不同场景下的角色和使用场景,文章还通过一个流程示例展示了各模型如何协... 目录Java领域模型(POJO / Entity / VO/ DTO / BO)一、为什么需要领域模

深入理解Redis线程模型的原理及使用

《深入理解Redis线程模型的原理及使用》Redis的线程模型整体还是多线程的,只是后台执行指令的核心线程是单线程的,整个线程模型可以理解为还是以单线程为主,基于这种单线程为主的线程模型,不同客户端的... 目录1 Redis是单线程www.chinasem.cn还是多线程2 Redis如何保证指令原子性2.

Java枚举类型深度详解

《Java枚举类型深度详解》Java的枚举类型(enum)是一种强大的工具,它不仅可以让你的代码更简洁、可读,而且通过类型安全、常量集合、方法重写和接口实现等特性,使得枚举在很多场景下都非常有用,本文... 目录前言1. enum关键字的使用:定义枚举类型什么是枚举类型?如何定义枚举类型?使用枚举类型:2.

Java中Redisson 的原理深度解析

《Java中Redisson的原理深度解析》Redisson是一个高性能的Redis客户端,它通过将Redis数据结构映射为Java对象和分布式对象,实现了在Java应用中方便地使用Redis,本文... 目录前言一、核心设计理念二、核心架构与通信层1. 基于 Netty 的异步非阻塞通信2. 编解码器三、