MultiHeadAttention在Tensorflow中的实现原理

2024-04-22 04:52

本文主要是介绍MultiHeadAttention在Tensorflow中的实现原理,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


前言

通过这篇文章,你可以学习到Tensorflow实现MultiHeadAttention的底层原理。


一、MultiHeadAttention的本质内涵

1.Self_Atention机制

MultiHeadAttention是Self_Atention的多头堆嵌,有必要对Self_Atention机制进行一次深入浅出的理解,这也是MultiHeadAttention的核心所在。

Self_Attention并不直接使用输入向量,而是先将其进行映射,使得输入向量在每个位置上产生一个query和context,context充当字典。在context的每个位置都提供一个key和value向量。

query:尝试去获取某类信息的序列。

context:包含key序列和value序列,是query感兴趣的内容。

最终输出的形状将与query序列相同。

一个常见的类比是,这种操作就像字典查询。一个模糊的、可区分的、矢量的字典查询。

如下是一个普通的 python 字典类型数据,有 3 个键和 3 个值,并被传递给一个query——"What color is it ?"。这个query会与key="color"最契合,最终得到查询结果value="blue"

query是你要尝试去找的东西。key表示字典里有哪些信息,而value就是这些信息。当你在正则字典中查找一个query时,字典会找到匹配的key,并返回其相关的value。这个查询要么有一个匹配的键,要么没有。你可以想象一个模糊的字典,其中的键不一定要完全匹配。如果你在上面的字典中查找 query—"What species is it ?",也许你希望它返回 key="type",value="pickup",因为那是与query最匹配的key和value。

注意力层就像这样做了一个模糊查找,但它不仅仅是在寻找最好的key,而是根据query与每个key的匹配程度来组合这些value。

那是如何做到这一点的呢?在注意力层中,query、key和value都是向量。注意力层不是简单地做哈希查找,而是结合query和key向量来确定它们的匹配程度——计算query和key的向量点积,再将所有匹配程度经过Softmax映射完后,即得到 "注意力得分"。最终该层返回所有value的加权平均值,以 "注意力分数 "为权重。

对于一段具体的文本来说,每一个字都会引发一个疑问query,并提供一个关键值key和一个目标内容value。每个query都会去寻找感兴趣的key,并按注意力分数提取并组合value,

图中越粗的红线对应的attention权重更大,query与key的紧密程度也越近。attention权重如此分布也是很符合情理的,要想回答query =“他是谁?”,我们很大可能会在“是”后面寻找答案,因为“爱人”提供的信息最多,所以它俩的attention权重最大。

总的来说,Self_Attention模拟的是一个符合人脑思维逻辑的研究过程。每当遇到一些新的信息,我们总会产生一定量的疑问(query),为了解决疑问,我们需要在信息中捕捉关键字(key),进而凝练出该关键字中所蕴涵的答案(value)。特定的疑问(query)需要联系特定的关键字(key),进而得出最终答案,这个最终答案往往是折合了不同value而得来的。

2.MultiHead_Atention机制

在不同情景中,字引发的query是不同的,例如,

“他是男的,已婚。”

query可以是”他的性别是什么?”,或者”他结婚了吗?”。不同的query会产生不同的注意力分数。单一的Self_attention无法捕捉多层面query和key之间的依赖关系,因为它只进行一次attention的分配。意在解决此类局限性,MultiHead_Atention会计算多次Self_attention。

利用MultiHead_Atention机制,可以为每一个输入学习到一个信息量丰富的向量表示。

二、使MultiHeadAttention在TensorFlow中的代码实现

1.参数说明

TensorFlow中是用tf.keras.layers.MultiHeadAttention()实现的。它的参数分为两类,一种参数为初始化参数,存在于__init__方法中;另一种为调用参数,存在于call方法中。

主要的初始化参数:

num_heads:Self_Attention的层数

key_dim:query和key多头映射层的输出shape在axis=-1上的长度。因为后续需要计算query和key的点积,所以需要保证query和key在最后一个轴上的长度相等。

value_dim:value多头映射层的输出shape在axis=-1上的长度。如果不指定,则默认等于key_dim

output_shape:  指定输入经过整个MultiHeadAttention层后的输出shape,默认与进入query多头映射层的输入shape相同

主要的call方法参数:

'''  B即Batch_size,每一批中的样本数;

    T是query的个数,即一段序列产生的疑问个数;S是value和key的个数,即一段序列产生的关键字和关键信息的个数,序列产生的key和value是成对出现的,所以value映射层 和key映射层的输入张量在axis=1处的长度S相同。T和S是可以随意指定的,只需在样本集进入Embedding层之前,先通过一个dense层进行T和S的指定(T和S等于各自dense层中的神经元个数)。例如,文本集shape=(B, S),经过一个具有T个神经元的Dense层→shape=(B, T),再经Embedding层→shape=(B, T, dim),得到query映射层的输入张量。当然,如果不愿如此麻烦,可直接将经Embedding层得到shape=(B, S, dim)的张量作为query映射层的输入;

    dim通常是Embedding向量的长度(每个字对应一个Embedding向量)'''

query:输入query多头映射层且shape为(B, T, dim)的张量

value:输入value多头映射层且shape为(B, S, dim)的张量

key:输入key多头映射层且shape为(B, S, dim)的张量,如果未指定,则key=value

use_causal_mask:布尔值,是否开启causal_mask(因果掩码)机制

2.整体结构

tf.keras.layers.MultiHeadAttention类中call()方法的逻辑过程就是MultiHeadAttention的前向传播过程,我将其提炼成以下三部分,

        ''' 多头映射层 '''query = self._query_dense(query)key = self._key_dense(key)value = self._value_dense(value)''' 注意力层 '''attention_output, attention_scores = self._compute_attention(query, key, value, attention_mask, training)''' 输出层 '''attention_output = self._output_dense(attention_output)

3.多头映射层

由query多头映射层—query_dense,value多头映射层—value_dense,key多头映射层—key_dense组成。

每个映射层执行的张量运算是一样的,张量运算逻辑为,

                                                   ' abc , cde -> abde '               



该层的训练参数总数为,

4.注意力层

计算query与key之间的内积,张量运算逻辑为,

                                              ' aecd, abcd -> acbe '

内积能够反映向量之间的相关程度,内积结果越大则相关性越大,联系也越紧密。得到query和key的内积后,为了得到attention分数,需要将内积结果进行softmax映射。

sttention_scores张量可视作一个B行num_heads列的矩阵,其矩阵中的元素均是T行S列的注意力分数矩阵。当输入是大序列(比如音频序列)时,TransFormer需要维护的注意力分数矩阵将呈n^{2}曲线式增长,这种庞大的数据量将会对TransFormer训练和推理的效率和速度产生严重的影响,在内存上的要求也会成n^{2}曲线式增长。


最后利用attention分数对value进行加权叠加,张量运算逻辑为,

                                                        'acbe,aecd->abcd' 

  



注意:如果指定use_causal_mask=True引入Causal_Mask(因果掩码)机制,则在softmax映射时,会传入一个左下三角为True右上三角为False的,布尔类型的,且与attention_scores.shape相同掩码张量,此时掩码张量中为False的对应位置(对应attention内积张量)将会被softmax忽略。如此一来就会导致每个query只会与当前及其以前的key进行内积,并不会考虑未来的key。进而导致在每个query处产生的新value只会是当前value与过往value在sttention分数上的加权叠加。这样的结构是因果的,符合在预测中结果会对输入产生影响的因果逻辑。因果掩码会在Decoder中使用。


注意力层无可训练的参数。

5.输出映射层

属于MultiHeadAttention的最后一层,负责将注意力层得到的value在sttention分数上的加权叠加后的张量进行输出映射。张量运算逻辑为,

                                                       ' abcd, cde -> abe '



该层训练参数总共为,


验证

import tensorflow as tflayer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.Input(shape=[9, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(query=target, value=source,return_attention_scores=True)''' 手动计算训练参数总数 '''
sum = 16*2*2*3+2*2*3+2*2*16+16
print(f'手动计算的训练参数总数为 : {sum}')
print(f'训练参数总共为 : {layer.count_params()}')
print(f'输出shape为 : {output_tensor.shape}')
print(f'注意力分数shape为 : {weights.shape}')手动计算的训练参数总数为 : 284
训练参数总共为 : 284
输出shape为 : (None, 9, 16)
注意力分数shape为 : (None, 2, 9, 4)

这篇关于MultiHeadAttention在Tensorflow中的实现原理的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/924855

相关文章

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

hdu4407(容斥原理)

题意:给一串数字1,2,......n,两个操作:1、修改第k个数字,2、查询区间[l,r]中与n互质的数之和。 解题思路:咱一看,像线段树,但是如果用线段树做,那么每个区间一定要记录所有的素因子,这样会超内存。然后我就做不来了。后来看了题解,原来是用容斥原理来做的。还记得这道题目吗?求区间[1,r]中与p互质的数的个数,如果不会的话就先去做那题吧。现在这题是求区间[l,r]中与n互质的数的和

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略 1. 特权模式限制2. 宿主机资源隔离3. 用户和组管理4. 权限提升控制5. SELinux配置 💖The Begin💖点点关注,收藏不迷路💖 Kubernetes的PodSecurityPolicy(PSP)是一个关键的安全特性,它在Pod创建之前实施安全策略,确保P

工厂ERP管理系统实现源码(JAVA)

工厂进销存管理系统是一个集采购管理、仓库管理、生产管理和销售管理于一体的综合解决方案。该系统旨在帮助企业优化流程、提高效率、降低成本,并实时掌握各环节的运营状况。 在采购管理方面,系统能够处理采购订单、供应商管理和采购入库等流程,确保采购过程的透明和高效。仓库管理方面,实现库存的精准管理,包括入库、出库、盘点等操作,确保库存数据的准确性和实时性。 生产管理模块则涵盖了生产计划制定、物料需求计划、