动画图解Attention机制,让你一看就明白

2024-06-21 09:32

本文主要是介绍动画图解Attention机制,让你一看就明白,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

点击上方“AI公园”,关注公众号,选择加“星标“或“置顶”


作者:Raimi Karim

编译:ronghuaiyang

导读

之前分享了几次attention的文章,感觉意犹未尽,这次用GIF来解释Attention机制,让人一看就明白,并解释如何用在Google Translate之类的机器翻译场景中。

640?wx_fmt=png

几十年来,统计机器翻译一直是占主导地位的翻译模型,直到神经机器翻译 (NMT)的诞生。NMT是一种新兴的机器翻译方法,它试图构建和训练单个大型的神经网络,来读取输入文本并输出对应的翻译。

NMT的先驱是Kalchbrenner and Blunsom (2013)Sutskever et. al (2014)Cho. et. al (2014b),其中比较熟悉的框架是来自Sutskever et. al.的序列到序列(seq2seq)模型。本文会基于seq2seq框架,并描述如何在此基础上构建注意力机制。

640?wx_fmt=png

Fig. 0.1: 输入长度为4的seq2seq


seq2seq的想法是,有两个递归神经网络(RNNs),使用encoder-decoder架构:一个接一个读取输入单词,得到一个固定的维度的向量表示(编码器),而且,以这些输入为条件,使用另一个RNN解码器一个一个地提取输出单词。

640?wx_fmt=png

Fig. 0.2: 输入序列长度为64的seq2seq


seq2seq的问题是,解码器从编码器接收到的唯一信息是编码器的最后隐藏状态(图0.1中的两个红色节点),这是一个向量表示,类似于输入序列的数值摘要。因此,对于长输入文本(图0.2),我们期望解码器只使用这一个 向量表示(希望它“充分描述输入序列”)来输出翻译是不现实的。这可能会导致灾难性的遗忘。这段有100个单词。你能在问号后面把这段话翻译成你知道的另一种语言吗?

如果我们做不到,那么我们就不应该对解码器如此残忍。那么,如果不光给一个向量表示,同时还给解码器一个来自每个编码器时间步长的向量表示,这样它就可以做出具有充足信息的翻译了,这个想法怎么样?让我们进入注意力机制

640?wx_fmt=png

Fig 0.3:添加注意力机制作为编码器和解码器之间的接口。在这里,第一个解码器时间步在给出第一个翻译单词之前接收来自编码器的信息。


注意力机制是编码器和解码器之间的接口,它向解码器提供来自每个编码器隐藏状态的信息(图0.3中红色隐藏状态除外)。通过这个设置,模型能够选择性地关注输入序列的有用部分,从而学习它们之间的“对齐”。这有助于模型有效地处理长输入语句。

定义:对齐

对齐是指将原文的片段与其对应的译文片段进行匹配。

640?wx_fmt=png

Fig. 0.3: 法语单词“la”在输入序列中的对齐分布,主要集中在这四个单词上:“the”、“European”、“Economic”和“Area”。深紫色表示更好的注意力得分


有两种注意类型,使用所有编码器隐藏状态的注意力类型也称为“全局注意力”。相反,“局部注意力”只使用编码器隐藏状态的子集。由于本文的范围是全局attention,因此本文中提到的“attention”均被视为“全局attention”。

本文使用动画来描述注意力的工作原理,这样我们就能不需要数学符号来进行理解。例如,我将分享过去5年设计的4个NMT架构。我还将在这篇文章中加入一些关于一些概念的直觉,所以请注意它们!

1. Attention: 概要

在我们了解注意力是如何使用之前,请允许我与你分享使用seq2seq模型进行翻译任务背后的直觉。

直觉:seq2seq

翻译员从头到尾阅读德语文本。一旦完成,他开始逐字逐句地翻译成英语。如果这个句子非常长,他很可能已经忘记了他在前面读到的内容。

这是一个简单的seq2seq模型。我要进行的注意力层的逐步计算是一个seq2seq+attention模型。这里有一个关于这个模型的快速直觉。

直觉: seq2seq + attention

翻译员在从头到尾阅读德语文本的同时写下关键词,然后开始翻译成英语。在翻译每一个德语单词时,他会利用他写下的关键词。

注意力集中在不同的单词上,给每个单词打分。然后,使用softmax之后分数,我们使用编码器隐藏状态的加权和来聚合编码器隐藏状态,得到上下文向量。注意力层的实现可以分为4个步骤。

Step 0: 准备隐藏状态.

我们首先准备第一个解码器的隐藏状态(红色)和所有可用的编码器隐藏状态(绿色)。在我们的示例中,我们有4个编码器隐藏状态和当前解码器隐藏状态。

640?wx_fmt=gif

Fig. 1.0: 注意力的准备工作


Step 1: 得到每一个编码器隐藏状态的得分.

score(标量)由score函数(也称为alignment score函数alignment model)获得。在本例中,score函数是解码器和编码器隐藏状态之间的点积。

请看Appendix A中各种各样的score函数。

640?wx_fmt=gif

Fig. 1.1: 获得score


 

在上面的例子中,对于编码器的隐藏状态[5, 0, 1],我们获得了较高的注意分值60。这意味着下一个要翻译的单词将受到这种编码器隐藏状态的严重影响。

Step 2: 把所有的得分送到softmax层跑一下.

我们将分数送到softmax层中,这样softmax之后的分数(标量)加起来等于1。这些softmax的分数代表了注意力分布

640?wx_fmt=gif

Fig. 1.2: 获得softmax之后的得分


 

注意,基于softmax之后的分数score^,注意力的分配仅按预期放在了[5, 0, 1]上。实际上,这些数字不是二进制的,而是0到1之间的一个浮点数。

Step 3: 用每个编码器的隐藏状态乘以softmax之后的得分.

通过将每个编码器的隐藏状态与其softmax之后的分数(标量)相乘,我们得到对齐向量标注向量。这正是对齐产生的机制。

640?wx_fmt=gif

Fig. 1.3: 得到对齐的向量


 

在这里,我们看到除了[5, 0, 1]外,所有编码器隐藏状态的对齐都被降低到0,这是因为注意力得分较低。这意味着我们可以期望第一个被翻译的单词应该与输入单词使用[5, 0, 1]嵌入表示的单词匹配。

Step 4: 把所有对齐的向量加起来.

对对齐向量进行求和,生成上下文向量。上下文向量是前一步的对齐向量的聚合信息。

640?wx_fmt=gif

Fig. 1.4: 得到上下文向量


Step 5: 把上下文向量送到解码器中.

这取决于体系结构设计。稍后,我们将在第2a、2b和2c节的示例中看到架构如何使用上下文向量作为解码器。

640?wx_fmt=gif

Fig. 1.5: 把上下文向量送到解码器中


就是这些了,这就是完整的动画:

640?wx_fmt=gif


Fig. 1.6: 注意力机制


直觉: attention到底是怎么工作的?

答案:反向传播,意不意外。反向传播将尽一切努力确保输出接近基本事实。这是通过改变RNNs和score函数(如果有的话)中的权重来实现的。这些权重将影响编码器的隐藏状态和解码器的隐藏状态,从而影响注意力得分。

2. Attention: 例子

在前一节中,我们已经看到了seq2seq和seq2seq+attention体系结构。在下一小节中,让我们研究另外3个基于seq2seq的NMT架构,它们实现了注意力。为了完整起见,我还附加了他们的Bilingual Evaluation Understudy(BLEU)分数,这是一个评价生成的句子到参考句子的标准度量。

2a. Bahdanau et. al (2015)

Neural Machine Translation by Jointly Learning to Align and Translate (Bahdanau et. al, 2015)

这种注意力的实现是注意力的创始者之一。作者在论文 “Neural Machine Translation by Learning to Jointly Align and Translate”的标题中使用了align这个词,意思是在训练模型的同时调整直接负责分数的权重。以下是关于这个架构需要注意的事项:

  1. 编码器是一个双向(前向+后向)门控循环单元(BiGRU)。解码器是一个GRU,其初始隐藏状态是由向后编码器GRU的最后一个隐藏状态修改而来的向量(下图中未显示)。

  2. 注意层中的score函数是additive/concat

  3. 下一个解码器时间步的输入是前一个解码器时间步(粉红色)的输出与当前时间步(深绿色)的上下文向量之间的拼接。

640?wx_fmt=png

Fig. 2a: NMT from 来自Bahdanau et. al. 的NMT,Encoder是BiGRU, decoder是GRU.


作者在WMT ' 14 English-to-French数据集上获得了26.75的BLEU评分。

直觉: 使用双向编码器+attention的seq2seq

翻译员A在阅读德语文本的同时写下关键词。翻译员B(他之所以担任高级职位,是因为他有一种额外的能力,可以把一个句子从后往前读)把同一个德语文本从最后一个单词读到第一个单词的同时记下关键词。这两个人定期讨论到目前为止他们读到的每一个单词。阅读完这篇德语文本后,翻译员B的任务是基于他们两人的讨论同时选择的关键字,将德语句子逐字逐句翻译成英语

翻译员A是前向RNN,翻译员B是后向RNN.

2b. Luong et. al (2015)

Effective Approaches to Attention-based Neural Machine Translation (Luong et. al, 2015)

"Effective Approaches to Attention-based Neural Machine Translation"一文的作者指出,简化和泛化Bahdanau et. al的体系结构非常重要。方法如下:

  1. 编码器是一个两层的长短时记忆(LSTM)网络。解码器也具有相同的结构,其初始隐藏状态是最后一个编码器的隐藏状态。

  2. 他们实验过的score函数是(i) additive/concat, (ii) 点积,(iii) location-based,和(iv) “general”

  3. 将当前解码器时间步的输出与当前时间步的上下文向量拼接,输入前向神经网络,得到当前解码器时间步的最终输出(粉红色)。

640?wx_fmt=png

Fig. 2b: 来自Luong et. al. 的NMT,Encoder是2层LSTM,decoder也一样.


在WMT ' 15 English-to-German数据集上获得了25.9的BLEU评分。

直觉: 使用2层编码器+attention的seq2seq

翻译员A在阅读德语文本的同时写下关键词。同样地,翻译员B(比译者A更资深)在阅读相同的德语文本时也会记下关键词。注意,初级翻译人员A必须向翻译人员B报告他们读到的每一个单词。一旦阅读完毕,他们就会根据所掌握的关键词,一个词一个词地把句子翻译成英语。

2c. Google的神经机器翻译(GNMT)

Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation (Wu et. al, 2016)

因为我们大多数人肯定都以这样或那样的方式使用过谷歌翻译,所以我觉得有必要谈谈谷歌在2016年实现的NMT。GNMT是前面两个例子的组合(深受第一个例子的启发)。

  1. 编码器由8个LSTMs组成,其中第一个LSTMs是双向的(其输出是拼接起来的),来自连续层(从第3层开始)的输出之间存在残差连接。解码器是8个单向LSTMs的“独立”堆叠。

  2. 使用的score函数是additive/concat,类似于第一个例子。

  3. 同样,就像第一个例子中一样,下一个解码器时间步的输入是前一个解码器时间步(粉红色)的输出和当前时间步(深绿色)的上下文向量之间的连接。

640?wx_fmt=png

Fig. 2c: 为Google翻译做的NMT。跳跃连接用曲线箭头表示。注意,LSTM单元格只显示隐藏状态和输入,它不显示单元格状态输入。


该模型在WMT ' 14 English-to-French上得到了38.95的BLEU评分,在WMT’14 English-to-German上得到了24.17的BLEU评分。

直觉: GNMT — 使用8个层叠的编码器(+bidirection+residual connections)+ attention的seq2seq

8名翻译员从下到上坐成一列,从翻译员A, B,…,H开始。每个翻译员都阅读相同的德语文本。对于每一个字,翻译员A和翻译员B共享他们之间的发现,然后翻译员B改善这个发现,再和翻译员C共享,重复这个过程,直到翻译员H。此外,阅读德语文本时,翻译员H基于他知道的东西和他收到的信息写下相关的关键字。

一旦所有人都读完了这篇德文文本,翻译员A被要求翻译第一个单词。首先,他试着回忆,然后他与翻译员B分享他的答案,翻译员B改进了答案并与翻译员C分享——重复这个过程,直到轮到翻译员H。翻译员H根据他所写的关键词和他得到的答案写出了第一个翻译单词。重复这一步骤,直到我们得到译文为止

3. 总结

下面是你在本文中看到的所有体系结构的快速总结:

  • seq2seq + attention

  • seq2seq with bidirectional encoder + attention

  • seq2seq with 2-stacked encoder + attention

  • GNMT — seq2seq with 8-stacked encoder (+bidirection+residual connections) + attention

就这么多了!在下一篇文章中,我将向你介绍自注意力的概念,以及它在谷歌的Transformer和自注意力生成对抗网络(SAGAN)中的应用。

Appendix: Score函数

下面是一些由Lilian Weng编辑的评分函数。在这篇文章中提到了Additive/concat和点积。score函数涉及点积运算(点积、余弦相似度等),其思想是度量两个向量之间的相似度。对于前馈神经网络评分函数,其思想是让模型在变换的同时学习对齐权值。

640?wx_fmt=png

Fig. A0: score函数的总结

640?wx_fmt=png

Fig. A1: score函数的总结
640?wx_fmt=png— END—

英文原文:https://towardsdatascience.com/attn-illustrated-attention-5ec4ad276ee3

640?wx_fmt=jpeg

请长按或扫描二维码关注本公众号

喜欢的话,请给我个好看吧640?wx_fmt=gif

这篇关于动画图解Attention机制,让你一看就明白的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1080849

相关文章

最好用的WPF加载动画功能

《最好用的WPF加载动画功能》当开发应用程序时,提供良好的用户体验(UX)是至关重要的,加载动画作为一种有效的沟通工具,它不仅能告知用户系统正在工作,还能够通过视觉上的吸引力来增强整体用户体验,本文给... 目录前言需求分析高级用法综合案例总结最后前言当开发应用程序时,提供良好的用户体验(UX)是至关重要

基于Python实现PDF动画翻页效果的阅读器

《基于Python实现PDF动画翻页效果的阅读器》在这篇博客中,我们将深入分析一个基于wxPython实现的PDF阅读器程序,该程序支持加载PDF文件并显示页面内容,同时支持页面切换动画效果,文中有详... 目录全部代码代码结构初始化 UI 界面加载 PDF 文件显示 PDF 页面页面切换动画运行效果总结主

Spring使用@Retryable实现自动重试机制

《Spring使用@Retryable实现自动重试机制》在微服务架构中,服务之间的调用可能会因为一些暂时性的错误而失败,例如网络波动、数据库连接超时或第三方服务不可用等,在本文中,我们将介绍如何在Sp... 目录引言1. 什么是 @Retryable?2. 如何在 Spring 中使用 @Retryable

Qt QWidget实现图片旋转动画

《QtQWidget实现图片旋转动画》这篇文章主要为大家详细介绍了如何使用了Qt和QWidget实现图片旋转动画效果,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 一、效果展示二、源码分享本例程通过QGraphicsView实现svg格式图片旋转。.hpjavascript

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

Flutter 进阶:绘制加载动画

绘制加载动画:由小圆组成的大圆 1. 定义 LoadingScreen 类2. 实现 _LoadingScreenState 类3. 定义 LoadingPainter 类4. 总结 实现加载动画 我们需要定义两个类:LoadingScreen 和 LoadingPainter。LoadingScreen 负责控制动画的状态,而 LoadingPainter 则负责绘制动画。

【编程底层思考】垃圾收集机制,GC算法,垃圾收集器类型概述

Java的垃圾收集(Garbage Collection,GC)机制是Java语言的一大特色,它负责自动管理内存的回收,释放不再使用的对象所占用的内存。以下是对Java垃圾收集机制的详细介绍: 一、垃圾收集机制概述: 对象存活判断:垃圾收集器定期检查堆内存中的对象,判断哪些对象是“垃圾”,即不再被任何引用链直接或间接引用的对象。内存回收:将判断为垃圾的对象占用的内存进行回收,以便重新使用。

【Tools】大模型中的自注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 自注意力机制(Self-Attention)是一种在Transformer等大模型中经常使用的注意力机制。该机制通过对输入序列中的每个元素计算与其他元素之间的相似性,