FlashAttention之我见

2024-09-04 20:44
文章标签 flashattention 之我见

本文主要是介绍FlashAttention之我见,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Attention机制可以算是Transformer的灵魂。正因为有了attention,模型的效果才能大幅提升。但同样是因为attention,导致transformer很难处理超长上下文,因为attention占用显存的大小与上下文长度的平方成正比,会导致上下文很长时显存爆炸。FlashAttention正是为了解决显存爆炸而设计的,它不光解决了显存爆炸的问题,同时也加速了attention的计算,并从数学上保证了结果的一致性。

1. Self Attention原理

Attention的计算涉及三个矩阵:Q、K、V,这三个矩阵在送给attention计算时都有相同的维度。我们先不考虑multi-head attention,只考虑one head的self attention。初始时,这三个矩阵的维度均为N x d,N即为上下文的长度(当前的大模型普遍支持的N上限为128K,谷歌也有大模型可以到1M)。通过下面的公式计算attention矩阵:
O = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V O=Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d})V O=Attention(Q,K,V)=softmax(d QKT)V
我们将attention的运算拆分,首先是 Q Q Q K T K^T KT的矩阵乘,它会生成一个N x N的矩阵,算法复杂度为 O ( N 2 d ) O(N^2d) O(N2d)。接着对N x N矩阵每一个元素进行缩放,并对每一行求softmax,这个过程不会改变矩阵的维度,算法复杂度为 O ( N 2 ) O(N^2) O(N2)。最后再和矩阵V相乘得到结果O,最终的矩阵维度N x d,和输入的三个矩阵的维度保持一致,算法复杂度同样为 O ( N 2 d ) O(N^2d) O(N2d),所以总的算法复杂度也为 O ( N 2 d ) O(N^2d) O(N2d)。attention运算会保持输入和输出矩阵维度的一致性,但是在具体的实现过程中,我们还是不可避免的要产生一个N x N的矩阵,这个矩阵在超长上下文下的显存占用非常可观。按照当前大模型普遍的上下文长度要求128K来算,N x N矩阵的大小为128Kx128K=16G,如果矩阵用fp16来存储,那单单这一个矩阵就要占32G显存,真的是显存占用不可限量!比较令人郁闷的是,我们最终计算的结果是N x d的矩阵,一般情况下d都要远远小于N(假设d为4K,最终的矩阵大小也仅为512M),但是为了保证结果的正确性我们不得不生成一个N x N的大矩阵。这也是为什么早期的大模型支持的上下文长度普遍较短的原因–attention占用的显存太大。

2. Multi-head Attention原理

在真正使用attention的时候,我们往往采用multi-head attention。Multi-head attention的计算公式和self attention基本一致,它改变了 Q 、 K 、 V Q、K、V QKV每一行的定义:将维度d的向量分成h组变成一个 h ∗ d k h * d_k hdk的矩阵, Q 、 K 、 V Q、K、V QKV此时成为了 N ∗ h ∗ d k N * h * d_k Nhdk的三维矩阵(不考虑batch维)。分别将 Q 、 K 、 V Q、K、V QKV的第一和第二维进行转置得到三个维度为 h ∗ N ∗ d k h * N * d_k hNdk的三维矩阵。此时的三个矩阵就是具有h个头的 Q 、 K 、 V Q、K、V QKV,我们就可以按照self attention的定义计算h个头的attention值。算法复杂度为 O ( h N 2 d k ) = O ( N 3 ) O(hN^2d_k)=O(N^3) O(hN2dk)=O(N3)。与self attention相比,multi-head attention的算法复杂度由 O ( N 2 d ) O(N^2d) O(N2d)变为 O ( N 3 ) O(N^3) O(N3),在长上下文的情况下计算量增大了很多。不仅计算量增大很多,在中间运算的过程中会产生h个N x N的矩阵的矩阵,所以显存占用也是大了h倍。这也说明如果不优化显存占用,multi-head attention的上下文长度无法扩展到很大。

3. FlashAttention原理

正当大家通过购买大显存显卡或者采用稀疏注意力来增大上下文长度时,22年斯坦福的一个博士吹岛(Tri Dao)发了一篇论文《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》,FlashAttention应运而生。FlashAttention不光解决了multi-head attention在计算过程中显存占用过大的问题,将 O ( N 2 ) O(N^2) O(N2)级的显存占用优化到了 O ( N d ) O(Nd) O(Nd),而且它的计算结果还是精确的,不是近似注意力,更重要的是attention的计算速度也变快了很多。从这篇论文开始,大模型支持的上下文长度普遍由16K上升到128K,甚至更长。
FlashAttention解决显存占用过大采用的方法就是分块(Tiling),将完整 Q K T QK^T QKT的计算分成一个一个小块来实现。因为块足够小,小到可以直接在共享内存(shared memory)中放下,从而加快了显存的访问,进而也加快了计算速度。分块不是特别新鲜的技术,矩阵乘的加速就利用了分块的技巧,但是self attention的计算不止 Q K T QK^T QKT,还包括对该矩阵求softmax。FlashAttention最大的创新点就来自于分块计算的同时还保证了softmax计算的正确性。我们知道,对一个向量求softmax,我们必须得获得完整的向量才可以,显然分块破坏了这个前提。理论上,没有获得完整的向量前我们是无法计算softmax的,但是我们可以在分块的过程中迭代计算softmax,原理也不是那么复杂,下面给出解释。

3.1 softmax的分块计算

首先,在实际计算softmax的时候,为了避免指数运算的数值溢出,往往会利用safe softmax求softmax,这两个是等价的,只不过多了一步求max的过程。给定一个向量 X = [ x 1 , x 2 , . . . , x d ] X=[x_1,x_2,...,x_d] X=[x1,x2,...,xd],我们先求得向量的最大值
m ( x ) = max ⁡ i ( x i ) \begin{equation} m(x)=\max_{i} (x_i) \end{equation} m(x)=imax(xi)
,有了最大值之后我们就可以利用公式:
s o f t m a x ( x i ) = e x i − m ( x ) ∑ j = 1 d e x j − m ( x ) \begin{equation} softmax(x_i)=\frac {e^{x_i-m(x)}} {\sum_{j=1}^d e^{x_j-m(x)}} \end{equation} softmax(xi)=j=1dexjm(x)exim(x)
求softmax。每一项减最大值再求exp就可以保证每一项的结果都在(0,1]之间,从而避免了exp过大导致溢出。
现在我们将向量 X X X分成 T c T_c Tc段,每段长度 B c B_c Bc,每一段分别去求safe softmax。从公式(2)可以看出,如果不做特殊处理,每一段计算的softmax值肯定是错的,因为softmax的计算要依赖一个全局的max值,但是每一段获取的只是一个局部max值,并且分母求和也只考虑了当前段的和没有考虑所有段的和。所以为了获得正确的softmax值,我们就需要让分子和分母都按照正确的定义去计算。FlashAttention通过在迭代计算每一段的safe softmax过程中维护三个变量 m , f , ℓ m,f,ℓ mf使得迭代结束之后就可以获得正确的softmax值。
m m m的含义是每一段向量的最大值,也即
m i ( X ) = max ⁡ i ∗ B c ≤ j < ( i + 1 ) ∗ B c ( x j ) \begin{equation} m_i(X)=\max_{i*B_c \le j \lt (i+1)*B_c} (x_j) \end{equation} mi(X)=iBcj<(i+1)Bcmax(xj)
m的长度为 T c T_c Tc。在实际的代码中,FlashAttention还会维护一个到当前段为止的全局最大值,我们记为 M , M M,M MM的更新公式为
M i = max ⁡ ( M i − 1 , m i ) \begin{equation} M_i=\max (M_{i-1},m_i) \end{equation} Mi=max(Mi1,mi)
f的含义是每一段每一个元素利用当前段计算的最大值求 e x p ( x − m ) exp(x-m) exp(xm),也即
f i j ( X ) = e x j − m i ( 0 < i < T c , i ∗ B c ≤ j < ( i + 1 ) ∗ B c ) \begin{equation} f_{ij}(X)= e^{x_j-m_i} (0<i<T_c, i*B_c \le j <(i+1)*B_c) \end{equation} fij(X)=exjmi(0<i<Tc,iBcj<(i+1)Bc)
上述公式的含义是第i段中,每一个元素都减去段最大值 m i m_i mi并求指数, f f f的长度为d。正如前面所说,这样的计算方法是错误的,我们需要进行修正。假设我们当前遍历到第i段,这i段的 f f f值已经求出,根据公式(5)我们知道每一段的 f f f值减的都是段内局部最大值 m i m_i mi,正确的计算应该是减去前i段的全局最大值 M i M_i Mi,我们可以根据变量 m m m M M M来更新每一段的 f f f,方法就是每一段的 f f f值乘以 e m i − M i e^{m_i-M_i} emiMi,然后利用指数运算的公式我们可以得到:
f i j ( X ) ∗ e m i − M i = e x j − m i ∗ e m i − M i = e x j − m i + m i − M i = e x j − M i \begin{equation} f_{ij}(X) *e^{m_i-M_i}=e^{x_j-m_i}*e^{m_i-M_i}=e^{x_j-m_i+m_i-M_i}=e^{x_j-M_i} \end{equation} fij(X)emiMi=exjmiemiMi=exjmi+miMi=exjMi
如此我们就得到遍历完i段时正确的f值。当然,每一段我们都需要更新前面所有的f值,总的复杂度为 O ( d 2 ) O(d^2) O(d2)。但是f的计算更多是概念性的,在真实的代码中是不可能按照如此高的复杂度去计算所有的 f f f,原理在于我们最终计算完softmax之后还要和 V V V相乘,我们在迭代计算的过程中进行更新即可,此时复杂度为 O ( d ) O(d) O(d)
从上面的介绍可以看出,为了获得safe softmax中正确的分子值,我们可以在分段计算的过程中,将之前的结果乘以 e m i − M i e^{m_i-M_i} emiMi进行矫正,所以 e m i − M i e^{m_i-M_i} emiMi可以看做是矫正因子。同理,分母值的更新也可以通过矫正因子实现。
ℓ ℓ 表示每一段 f f f的累加值。也即
ℓ i ( X ) = ∑ j = i ∗ B c ( i + 1 ) ∗ B c − 1 f i j ( X ) \begin{equation} ℓ_{i}(X)= \sum_{j=i*B_c}^{(i+1)*B_c-1}f_{ij}(X) \end{equation} i(X)=j=iBc(i+1)Bc1fij(X)
ℓ ℓ 的长度为 T c T_c Tc。和f的计算类似,我们也可以通过将每一个 ℓ ℓ 乘以 e m i − M i e^{m_i-M_i} emiMi将求和的结果进行矫正。对 ℓ ℓ 求和就可以获得到当前段的所有元素求exp的和,也即safe softmax的正确分母值,我们记为 L L L。在更新 L L L的时候我们按照下述公式

这篇关于FlashAttention之我见的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1136979

相关文章

KMP算法之我见(初解)

几天前看到KMP算法的时候,头大如麻,略读一遍,决定跳过,学完了整章串、数组、矩阵和广义表之后回头专心研究KMP算法。 在学习这本《数据结构》的前几章的时候我就开始对这本教程有点失望了,当初在图书馆里对比了十几本教程选择了它,主要原因是它图解较多,便于理解,但是细读发现它的代码不够讲究,实用性不强,可能是我经验匮乏吧,反正这本教材的堆栈部分我很不满意,代码可用性太小,相较其他版本的结构体实现堆栈

FlashAttention-2 论文阅读笔记

FlashAttention-2是对原始FlashAttention算法的一系列改进,旨在优化在GPU上的计算性能。本节详细讨论了FlashAttention-2的算法、并行性以及工作分区策略。 算法 FlashAttention-2的关键优化点在于减少非矩阵乘法(matmul)的浮点运算,以充分利用GPU上的专用计算单元(如Nvidia GPU上的Tensor Cores),这些单元在处理m

android学习之我见

首先说说Android开发环境的搭建吧,本来这是一件很容易的事情,但是很多同学依旧会遇到很多的问题,建议就是在网站下一个Android开发环境的教程,然后一步一步去搭建。同学一遇到有问题首先是要自己去看看哪里出问题了,如果找不到问题,应该是想到搜索引擎才对的,谷歌肯定会给你最好的解答的。     一、Android开发起步其次说的是刚刚起步学习的同学,假如开始没有任何的开发经验的话,千万不要着

监控之我见

我们想像中的监控? 我们想像中监控无所不能,是个超人。需要什么数据,它就能给我们什么数据;需要找到故障根源,它就能及时告知我们故障根源。 现实中的监控 可事实上并非如此,我们对监控寄予了太多,想到的就加上去,导致它越来越胖,越来越臃肿,但似乎并未解决我们的问题。 目前的监控平台和工具都很多,开源的、商业的、甚至我们自己开发的,但都

Mysql之我见一(基础知识)

1.Mysql简介 2.Mysql配置文件 3.Mysql逻辑架构 和其它数据库相比,Mysql有点与众不同,它的架构可以在多种不同的场景中应用并发挥良好作用。主要体现在存储引擎的架构上,插拔式存储架构将查询处理和其它的系统任务以及数据的存储提取相分离。这种架构可以根据业务的需求和实际需要选择合适的存储引擎。

WebService学习之我见

1.WebService是什么 (1)基于Web的服务,服务器端整出来一些资源让客户端应用访问(获取数据) (2)一个跨语言、跨平台的规范(抽象) (3)多个跨平台、跨语言的应用间通信整合的方案(实际) 以各个网站显示天气预报功能为例: 气象中心的管理系统将收集的天气信息并将数据暴露出来(通过WebService Server),而各大站点的应用就去调

oracle中常用连接之我见

测试脚本: 创建左表: createtable L asselect'left_1'as str,'1'as v from dualunionallselect'left_2'as str,'2'as v from dualunionallselect'left_3'as str,'3'as v from dualunionallselect'left_4'as str,'4

试探何为RSS之我见

RSS这个名词想必大家在浏览某些网站的时候会见过,例如腾讯公司的RRS网站:rss.qq.com/cq.htm等。那么我们会问,到底什么是RSS,它的作用是什么? 以下转自【百度百科】 RSS是站点和站点之间共享内容的一种简易方式【也称之为“聚合内容”】,是一种描述和同步网站内容的格式。 RSS可以是以下三个解释的其中一个: Really Simple Syndication

资深码农谈:嵌入式C语言可靠性设计之我见

前言   设备的可靠性涉及多个方面:稳定的硬件、优秀的软件架构、严格的测试以及市场和时间的检验等等。这里着重谈一下对嵌入式软件可靠性设计的一些理解,通过一定的技巧和方法提高软件可靠性。这里所说的嵌入式设备,是指使用单片机、ARM7、Cortex-M0,M3之类为核心的测控或工控系统。   嵌入式软件可靠性设计应该从防错、判错和容错三方面进行考虑. 此外,还需理解自己所使用的编译器特性。

“国产数据库”之我见

最近两年,无论是商业、开源,还是分布式、云原生等越来越多的国产数据库如雨后春笋般悄然出现在公众的视野。面对这么多种国产数据库,无论对客户还是我们数据库从业人员,难免会产生一种眼花缭乱的现象。令我喜出望外的是云和恩墨公司旗下的墨天轮社区在2019年6月推出了国产数据库流行度排行榜。下面三张图分别是排名“前十”的国产数据库从2019年6月到2021年1月的流行度排行趋势,最近三个月的国产数据库得分和热