本文主要是介绍FlashAttention之我见,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
Attention机制可以算是Transformer的灵魂。正因为有了attention,模型的效果才能大幅提升。但同样是因为attention,导致transformer很难处理超长上下文,因为attention占用显存的大小与上下文长度的平方成正比,会导致上下文很长时显存爆炸。FlashAttention正是为了解决显存爆炸而设计的,它不光解决了显存爆炸的问题,同时也加速了attention的计算,并从数学上保证了结果的一致性。
1. Self Attention原理
Attention的计算涉及三个矩阵:Q、K、V,这三个矩阵在送给attention计算时都有相同的维度。我们先不考虑multi-head attention,只考虑one head的self attention。初始时,这三个矩阵的维度均为N x d,N即为上下文的长度(当前的大模型普遍支持的N上限为128K,谷歌也有大模型可以到1M)。通过下面的公式计算attention矩阵:
O = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V O=Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d})V O=Attention(Q,K,V)=softmax(dQKT)V
我们将attention的运算拆分,首先是 Q Q Q和 K T K^T KT的矩阵乘,它会生成一个N x N的矩阵,算法复杂度为 O ( N 2 d ) O(N^2d) O(N2d)。接着对N x N矩阵每一个元素进行缩放,并对每一行求softmax,这个过程不会改变矩阵的维度,算法复杂度为 O ( N 2 ) O(N^2) O(N2)。最后再和矩阵V相乘得到结果O,最终的矩阵维度N x d,和输入的三个矩阵的维度保持一致,算法复杂度同样为 O ( N 2 d ) O(N^2d) O(N2d),所以总的算法复杂度也为 O ( N 2 d ) O(N^2d) O(N2d)。attention运算会保持输入和输出矩阵维度的一致性,但是在具体的实现过程中,我们还是不可避免的要产生一个N x N的矩阵,这个矩阵在超长上下文下的显存占用非常可观。按照当前大模型普遍的上下文长度要求128K来算,N x N矩阵的大小为128Kx128K=16G,如果矩阵用fp16来存储,那单单这一个矩阵就要占32G显存,真的是显存占用不可限量!比较令人郁闷的是,我们最终计算的结果是N x d的矩阵,一般情况下d都要远远小于N(假设d为4K,最终的矩阵大小也仅为512M),但是为了保证结果的正确性我们不得不生成一个N x N的大矩阵。这也是为什么早期的大模型支持的上下文长度普遍较短的原因–attention占用的显存太大。
2. Multi-head Attention原理
在真正使用attention的时候,我们往往采用multi-head attention。Multi-head attention的计算公式和self attention基本一致,它改变了 Q 、 K 、 V Q、K、V Q、K、V每一行的定义:将维度d的向量分成h组变成一个 h ∗ d k h * d_k h∗dk的矩阵, Q 、 K 、 V Q、K、V Q、K、V此时成为了 N ∗ h ∗ d k N * h * d_k N∗h∗dk的三维矩阵(不考虑batch维)。分别将 Q 、 K 、 V Q、K、V Q、K、V的第一和第二维进行转置得到三个维度为 h ∗ N ∗ d k h * N * d_k h∗N∗dk的三维矩阵。此时的三个矩阵就是具有h个头的 Q 、 K 、 V Q、K、V Q、K、V,我们就可以按照self attention的定义计算h个头的attention值。算法复杂度为 O ( h N 2 d k ) = O ( N 3 ) O(hN^2d_k)=O(N^3) O(hN2dk)=O(N3)。与self attention相比,multi-head attention的算法复杂度由 O ( N 2 d ) O(N^2d) O(N2d)变为 O ( N 3 ) O(N^3) O(N3),在长上下文的情况下计算量增大了很多。不仅计算量增大很多,在中间运算的过程中会产生h个N x N的矩阵的矩阵,所以显存占用也是大了h倍。这也说明如果不优化显存占用,multi-head attention的上下文长度无法扩展到很大。
3. FlashAttention原理
正当大家通过购买大显存显卡或者采用稀疏注意力来增大上下文长度时,22年斯坦福的一个博士吹岛(Tri Dao)发了一篇论文《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》,FlashAttention应运而生。FlashAttention不光解决了multi-head attention在计算过程中显存占用过大的问题,将 O ( N 2 ) O(N^2) O(N2)级的显存占用优化到了 O ( N d ) O(Nd) O(Nd),而且它的计算结果还是精确的,不是近似注意力,更重要的是attention的计算速度也变快了很多。从这篇论文开始,大模型支持的上下文长度普遍由16K上升到128K,甚至更长。
FlashAttention解决显存占用过大采用的方法就是分块(Tiling),将完整 Q K T QK^T QKT的计算分成一个一个小块来实现。因为块足够小,小到可以直接在共享内存(shared memory)中放下,从而加快了显存的访问,进而也加快了计算速度。分块不是特别新鲜的技术,矩阵乘的加速就利用了分块的技巧,但是self attention的计算不止 Q K T QK^T QKT,还包括对该矩阵求softmax。FlashAttention最大的创新点就来自于分块计算的同时还保证了softmax计算的正确性。我们知道,对一个向量求softmax,我们必须得获得完整的向量才可以,显然分块破坏了这个前提。理论上,没有获得完整的向量前我们是无法计算softmax的,但是我们可以在分块的过程中迭代计算softmax,原理也不是那么复杂,下面给出解释。
3.1 softmax的分块计算
首先,在实际计算softmax的时候,为了避免指数运算的数值溢出,往往会利用safe softmax求softmax,这两个是等价的,只不过多了一步求max的过程。给定一个向量 X = [ x 1 , x 2 , . . . , x d ] X=[x_1,x_2,...,x_d] X=[x1,x2,...,xd],我们先求得向量的最大值
m ( x ) = max i ( x i ) \begin{equation} m(x)=\max_{i} (x_i) \end{equation} m(x)=imax(xi)
,有了最大值之后我们就可以利用公式:
s o f t m a x ( x i ) = e x i − m ( x ) ∑ j = 1 d e x j − m ( x ) \begin{equation} softmax(x_i)=\frac {e^{x_i-m(x)}} {\sum_{j=1}^d e^{x_j-m(x)}} \end{equation} softmax(xi)=∑j=1dexj−m(x)exi−m(x)
求softmax。每一项减最大值再求exp就可以保证每一项的结果都在(0,1]之间,从而避免了exp过大导致溢出。
现在我们将向量 X X X分成 T c T_c Tc段,每段长度 B c B_c Bc,每一段分别去求safe softmax。从公式(2)可以看出,如果不做特殊处理,每一段计算的softmax值肯定是错的,因为softmax的计算要依赖一个全局的max值,但是每一段获取的只是一个局部max值,并且分母求和也只考虑了当前段的和没有考虑所有段的和。所以为了获得正确的softmax值,我们就需要让分子和分母都按照正确的定义去计算。FlashAttention通过在迭代计算每一段的safe softmax过程中维护三个变量 m , f , ℓ m,f,ℓ m,f,ℓ使得迭代结束之后就可以获得正确的softmax值。
m m m的含义是每一段向量的最大值,也即
m i ( X ) = max i ∗ B c ≤ j < ( i + 1 ) ∗ B c ( x j ) \begin{equation} m_i(X)=\max_{i*B_c \le j \lt (i+1)*B_c} (x_j) \end{equation} mi(X)=i∗Bc≤j<(i+1)∗Bcmax(xj)
m的长度为 T c T_c Tc。在实际的代码中,FlashAttention还会维护一个到当前段为止的全局最大值,我们记为 M , M M,M M,M的更新公式为
M i = max ( M i − 1 , m i ) \begin{equation} M_i=\max (M_{i-1},m_i) \end{equation} Mi=max(Mi−1,mi)
f的含义是每一段每一个元素利用当前段计算的最大值求 e x p ( x − m ) exp(x-m) exp(x−m),也即
f i j ( X ) = e x j − m i ( 0 < i < T c , i ∗ B c ≤ j < ( i + 1 ) ∗ B c ) \begin{equation} f_{ij}(X)= e^{x_j-m_i} (0<i<T_c, i*B_c \le j <(i+1)*B_c) \end{equation} fij(X)=exj−mi(0<i<Tc,i∗Bc≤j<(i+1)∗Bc)
上述公式的含义是第i段中,每一个元素都减去段最大值 m i m_i mi并求指数, f f f的长度为d。正如前面所说,这样的计算方法是错误的,我们需要进行修正。假设我们当前遍历到第i段,这i段的 f f f值已经求出,根据公式(5)我们知道每一段的 f f f值减的都是段内局部最大值 m i m_i mi,正确的计算应该是减去前i段的全局最大值 M i M_i Mi,我们可以根据变量 m m m和 M M M来更新每一段的 f f f,方法就是每一段的 f f f值乘以 e m i − M i e^{m_i-M_i} emi−Mi,然后利用指数运算的公式我们可以得到:
f i j ( X ) ∗ e m i − M i = e x j − m i ∗ e m i − M i = e x j − m i + m i − M i = e x j − M i \begin{equation} f_{ij}(X) *e^{m_i-M_i}=e^{x_j-m_i}*e^{m_i-M_i}=e^{x_j-m_i+m_i-M_i}=e^{x_j-M_i} \end{equation} fij(X)∗emi−Mi=exj−mi∗emi−Mi=exj−mi+mi−Mi=exj−Mi
如此我们就得到遍历完i段时正确的f值。当然,每一段我们都需要更新前面所有的f值,总的复杂度为 O ( d 2 ) O(d^2) O(d2)。但是f的计算更多是概念性的,在真实的代码中是不可能按照如此高的复杂度去计算所有的 f f f,原理在于我们最终计算完softmax之后还要和 V V V相乘,我们在迭代计算的过程中进行更新即可,此时复杂度为 O ( d ) O(d) O(d)。
从上面的介绍可以看出,为了获得safe softmax中正确的分子值,我们可以在分段计算的过程中,将之前的结果乘以 e m i − M i e^{m_i-M_i} emi−Mi进行矫正,所以 e m i − M i e^{m_i-M_i} emi−Mi可以看做是矫正因子。同理,分母值的更新也可以通过矫正因子实现。
ℓ ℓ ℓ表示每一段 f f f的累加值。也即
ℓ i ( X ) = ∑ j = i ∗ B c ( i + 1 ) ∗ B c − 1 f i j ( X ) \begin{equation} ℓ_{i}(X)= \sum_{j=i*B_c}^{(i+1)*B_c-1}f_{ij}(X) \end{equation} ℓi(X)=j=i∗Bc∑(i+1)∗Bc−1fij(X)
ℓ ℓ ℓ的长度为 T c T_c Tc。和f的计算类似,我们也可以通过将每一个 ℓ ℓ ℓ乘以 e m i − M i e^{m_i-M_i} emi−Mi将求和的结果进行矫正。对 ℓ ℓ ℓ求和就可以获得到当前段的所有元素求exp的和,也即safe softmax的正确分母值,我们记为 L L L。在更新 L L L的时候我们按照下述公式
这篇关于FlashAttention之我见的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!