本文主要是介绍机器学习周报(8.26-9.1),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文章目录
- 摘要
- Abstract
- self-attetion
- QKV理解
- 如何让self-attention更有效
- local attention/truncated attention方法
- stride attention方法
- Global Attention方法
- data driving方法
- Clustering
- sinkhorn sorting network
- 选取representative keys
- 减少Keys数量的方法
- self-attention
- Synthesizer
- 总结
摘要
本周先是好好理解了一下self-attention的QKV的理解,关于如何让自注意力机制更有效的问题,学习了self-attention的多种变形,包括减少注意力矩阵的计算量、加快注意力机制的运算速度、去掉attention等。
Abstract
This week, I first had a good understanding of the QKV of self-attention, about how to make the self-attention mechanism more effective, and learned a variety of variants of self-attention, including reducing the computation amount of attention matrix, speeding up the computation speed of attention mechanism, removing attention and so on.
self-attetion
QKV理解
以搜索查询商品为例:
query可以理解为输入要查询的商品;
key为商品的标签或者title;
value可理解为商品的评价之类的;
相似度=querykey(矩阵乘法) 根据相似度 召回
总分=相似度value 根据总分排序输出
Q(query):模型从token中提取出来的对token的理解信息,用于主动与其他token计算相似程度
K(key):模型从token提取出来的,与其他token的关系信息,被用于与其他token计算相似程度
V(value):表示当前token的重要程度
-
self-attention中self的理解
self-attention的self,表示query,key,value都来自自己,每个token都能提取出来自己的query,key,value -
计算过程
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k V ) Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}}V) Attention(Q,K,V)=softmax(dkQKTV)
使用具体例子看一下计算过程:
如何让self-attention更有效
自制力机制里面的具体计算过程,如果对目前大多数方法进行抽象的话,可以将其归纳为两个过程:第一个过程是根据Query和Key计算权重系数,第二个过程根据权重系数对Value进行加权求和。
第一个过程中,输入一个向量,可通过乘以不同的矩阵得到一个query和一个key的向量序列,长度都和输入序列一样(假设为N)。由query和key 两个序列做点积就可以得到attention matrix,这个运算量是NN级的。这种方式最大的问题就是当序列长度太长的时候,对应的 Attention Matrix 维度太大,会给计算带来麻烦。当N很小的时候,运算量放在整个网络里面可以忽略不计,但当N很大的时候,self-attention就有可能主导整个网络的运算量,这时优化self-attention的计算就可以得到显著的影响,这样我们加快self attention 才会对神经网络有帮助。
local attention/truncated attention方法
例如只看前后两个位置的时候,那么与其他位置的值就可以直接设置为0,例如图中灰色的位置。但是这个这个明显有问题,我们在做attention的时候只能看到小范围的数值,那这个就跟CNN非常相似了,local attention是可以加快我们的attention的方法,但是不一定能得到很好的结果。
stride attention方法
上面是看前后一步的位置,这样只能看到局部的信息,而stride attention可以看指定步长的邻居,因此可以考虑范围相对广一些,下图的例子考虑间隔两格的邻居,步长设置为2,根据实际问题需要可以设置不同的步长。
Global Attention方法
如果需要考虑所有的输入,又不想计算量太大,就可以用到global attention。核心思想是加入一个特殊token到原始的sequence里面,在global attention,每个特殊的token都加入每一个token,收集全局信息。每个特殊的token都被其他所有的token加入,以用来获取全局信息。
Longformer 就是组合了上面的三种 attention
Big Bird 就是在 Longformer 基础上随机选择 attention 赋值,进一步提高计算效率。
data driving方法
在一个self-attention里面的矩阵里面,某些位置有很大的值,有些位置又有很小的值,那我们是否可以把很小的值变为0,那我们是否能估计矩阵哪里有大值,哪里有小值吗?这个方法叫做clustering。
Clustering
- 我们先把query和key取出来,然后根据query和key的相近程度做clustering。对于相近的数据就放在一起,对于比较远的数据就属于不同的cluster。
下面我们有四个cluster,用不同的颜色来标出。
对于query和key形成的attention matrix来说,只有当query和key的cluster属于同一个的时候,我们才计算他们的attention weight。对于不属于同一个cluster的两个query和key,就把他们设为0。这种方法可以加速我们的运算,这是一种基于数据来决定的!
sinkhorn sorting network
上面的方法是通过人为决定attention matrix 里面哪些位置不需要计算。而在sinkhorn sorting network里面,机器自己直接学习另外一个network来决定怎么输出这个矩阵。
我们把输入的序列,经过一个NN之后产生另外一排向量序列,生成一个N×N的的矩阵。我们要把这个生成的不是二进制的矩阵变成我们的attention matrix。这个过程是不用经过二进制变换的,可以直接输出attention matrix。
我们并不需要一个full attention matrix,因为在一个attention matrix里会有很多冗余的列,很多列都是重复的,因此可以去掉冗余的列,缩小attention matrix,加快attention的速度呢。简化attention matrix的方法:减少计算attention的key的数量。
选取representative keys
假设有N个key,从中选取K个代表的key。然后与N个query序列相乘得到一个N×K的矩阵,然后从N个value,也选取K个代表value。然后我们把这K个value和attention matrix做weight sum加权和,就得到attention matrix layer的输出。
为什么选择代表key,而不选择代表query呢?
因为在self-attention里面输入和输出长度一致,如果改变了query的长度那么就改变了输出的长度,如果是输入一个序列输出一个数值的模型就可以选择代表query。
减少Keys数量的方法
- 用CNN来扫过输入的key序列,得到一个更短的序列,那这个就是代表性的key。
- 输入的key序列可以看成是一个d×N的矩阵,由线性代数知识可知,将一个k×N的矩阵乘上一个N×K的矩阵,然后就得到了d*K的矩阵。那这个得到的新矩阵就是代表性key序列。
self-attention
输入的向量I分别通过变换矩阵 W q , W k , W v W^q,W^k,W^v Wq,Wk,Wv得到Q,K,V矩阵
忽略softmax
下面这两种计算方式中,得到的结果是相同的,但是两者的计算速度相差甚远
- 第一个计算方法中, K T 和 Q K^T和Q KT和Q相乘的乘法次数为N×d×N,得到A(attention matrix),通过softmax得到 A ′ A' A′, V 与 A ′ V与A' V与A′的乘法次数为d×N×N,所以送的计算次数为: ( d + d ′ ) N 2 (d+d')N^2 (d+d′)N2
- 第二个计算方法中,总的计算次数为: 2 d ′ d N 2d'dN 2d′dN
- 加上softmax的计算过程
将上述 b b b的计算公式进行简化
由下图可以看出蓝色的 vector 和黄色的 vector 其实跟 b1 中的 1 是没有关系的。
也就是说,当我们算 b2、b3… 的时候,蓝色的 vector 和黄色的 vector 不需要再重复计算,大大减少了重复的计算量。
Synthesizer
总结
本周主要是复习了self-attention的基本原理的前提下,学习了对self-attention的一下更有效的方法,然后有些公式推导理解还不够透彻,我会继续研究推导理解
这篇关于机器学习周报(8.26-9.1)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!