本文主要是介绍通透理解FlashAttention与FlashAttention2:全面降低显存读写、加快计算速度,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
前言
成就本文有两个因素
- 第一个因素是,我带长沙的LLM项目团队做论文审稿GPT这个项目时,遇到了不少工程方面的问题(LLM方面的项目做多了,你会逐步发现,现在模型没啥秘密 技术架构/方向选型也不是秘密,最终都是各种工程细节的不断优化),比如数据的问题,再比如大模型本身的上下文长度的问题
前者已经得到了解决,详见此文《学术论文GPT的源码解读与微调:从ChatPaper到七月论文审稿GPT第1版》的第三部分
但后者相对麻烦些,原因在于审稿语料中一万多篇论文的长度基本都在万词以上,而通过本博客内之前的文章可以得知大部分模型的上下文长度基本都没超过8K模型 对应的上下文长度 论文审稿表现(凡是8K以内的长度均不够) GPT3.5 4-16K(后11.7日统一到了16K) 16K效果待测
另,23年11.7日开放了3.5的16K微调接口GPT4 8K-32K(后11.7日升级到128K) 待测 LLaMA 2048 LLaMA2 4096 LLaMA2-long(其23年9.27发的论文) 32K 效果待测 基于LongLoRA技术的LongAlpaca-7B/13B/70B 32K以上 效果待测 Baichuan-7B/13B、Baichuan 2-7B/13B 4096 ChatGLM-6B 2000 ChatGLM2-6B 8-32K 32K效果如何待定
- 第二个因素是,本文最初是作为ChatGLM2-6B的部分内容之一和第一代ChatGLM-6B的内容汇总在一块,而ChatGLM2-6B有一个比较突出的特点是其支持32K的上下文,而ChatGLM2是依据的FlashAttention技术实现的32K上下文(某种意义上降低了 attention的计算量,所以在同样的资源下可以算更长长度的attention)
所以为了阐述清楚FlashAttention、FlashAttention2等相关的原理,导致之前那篇文章越写越长,故特把FlashAttention相关的内容独立抽取出来成本文
至于LLaMA2-long和基于LongLoRA技术的LongAlpaca-7B/13B/70B,则分别见:《一文通透位置编码:从标准位置编码、旋转位置编码RoPE到ALiBi、LLaMA 2 Long》的最后部分、《大模型上下文长度的超强扩展:从LongLoRA到LongQLoRA(含源码剖析)》
本文会和本博客内其他大模型相关的文章一样,极其注重可读性
- 比如为了不断提高可读性,本文近期会不断反复修改,细抠标题的层级、措辞,甚至排版、标点符号,如果不通俗易懂,宁愿不写
- 如果你对某一节的某一个内容或某一个公式没看明白,请随时于本文评论下留言,一定及时修订以让君明白(友情提醒,本文假定大家已经熟悉了transformer,如果对transformer还不熟悉的话,建议先阅读此文:Transformer通俗笔记:从Word2Vec、Seq2Seq逐步理解到GPT、BERT,特别是其中的第三部分)
第一部分 Transformer的时空复杂度与标准注意力的问题
FlashAttention是斯坦福联合纽约州立大学在22年6月份提出的一种具有 IO 感知,且兼具快速、内存高效的新型注意力算法「对应论文为:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness,这是其GitHub地址」
它要解决一个什么样的问题呢?
- 首先,GPT3、LLaMA、ChatGLM、BLOOM等大语言模型输入输出的最大序列长度只有2048或4096,扩展到更长序列的难度在哪里呢?本质原因是,transformer模型的计算复杂度和空间复杂度都是 的,其中为序列长度
- 如此,FlashAttention提出了一种加速计算、节省显存和IO感知的精确注意力,可以有效地缓解上述问题
Meta推出的开源大模型LLaMA,阿联酋推出的开源大模型Falcon都使用了Flash Attention来加速计算和节省显存。目前,Flash Attention已经集成到了pytorch2.0中,另外triton、xformer等开源框架也进行了整合实现
1.1 Transformer计算复杂度——Self-Attention层与MLP层
简单理解的话,计算复杂度和序列长度的平方成正比,可以看一个小例子,比如两个相乘的矩阵大小分别为() 和(),矩阵乘法的一种计算方式是使用第一个矩阵的每一行与第二个矩阵的每一列做点乘
因为我们需要拿第一个矩阵的每一行去与第二个矩阵的每一列做点乘,所以总共就需要 次点乘。而每次点乘又需要 次乘法,所以总复杂度就为
精确理解的话,当输入批次大小为 ,序列长度为 时,
层transformer模型的计算量为 ,则代表词向量的维度或者隐藏层的维度(隐藏层维度通常等于词向量维度)
但这个结果是怎么一步一步计算得到的呢?下面,咱们来详细拆解这个计算过程
1.1.1 Self-Attention层的计算复杂度
首先,我们知道,transformer模型由 个相同的层组成,每个层分为两部分:self-attention块和MLP块
而self-attention层的模型参数有两部分,一部分是、、的权重矩阵、、和偏置,另一部分是输出权重矩阵和偏置,最终为:
具体怎么计算得来的呢?
- 第一步是计算、、
即
该矩阵乘法的输入和输出形状为
计算量为:- 计算
该部分的输入和输出形状为
计算量为:- 计算在上的加权
该部分矩阵乘法的输入和输出形状为
计算量为:- attention后的线性映射,矩阵乘法的输入和输出形状为
计算量为
最终自注意力层的输出结果为
1.1.2 MLP层的计算复杂度
MLP块由2个线性层组成,最终是
怎么计算得来的呢?
一般地,第一个线性层是将维度从映射到,第二个线性层再将维度从映射到
- 第一个线性层的权重矩阵 的形状为 ,相当于先将维度从 映射到,矩阵乘法的输入和输出形状为,计算量为
- 第二个线性层的权重矩阵 的形状为 ,相当于再将维度从 映射到 ,矩阵乘法的输入和输出形状为,计算量为
将上述所有表粗所示的计算量相加,得到每个transformer层的计算量大约为
1.1.3 logits的计算量:
此外,另一个计算量的大头是logits的计算(毕竟词嵌入矩阵的参数量也较多),将隐藏向量映射为词表大小,说白了,词向量维度通常等于隐藏层维度 ,词嵌入矩阵的参数量为,最后的输出层的权重矩阵通常与词嵌入矩阵是参数共享的「解释一下,如七月杜老师所说,这个是transformer中一个重要的点,参数共享可以减小参数量,词嵌入矩阵是[vocab_size,hidden_size],输出层矩阵是 [hidden_size,vocab_size],是可以共享的」
其矩阵乘法的输入和输出形状为,计算量为
因此,对于一个 层的transformer模型,输入数据形状为 的情况下,一次训练迭代的计算量为上述三个部分的综合,即:
1.2 Transformer的空间复杂度——Self-Attention层与MLP层
中间激活的显存大小为 ,其中 为注意力头数
大模型在训练过程中通常采用混合精度训练,中间激活值一般是float16或者bfloat16数据类型的。在分析中间激活的显存占用时,假设中间激活值是以float16或bfloat16数据格式来保存的,每个元素占了2个bytes。唯一例外的是,dropout操作的mask矩阵,每个元素只占1个bytes。在下面的分析中,单位是bytes,而不是元素个数。
每个transformer层包含了一个self-attention块和MLP块,并分别对应了一个layer normalization连接。
1.2.1 Self-Attention块的中间激活
self-attention块的计算公式如下:
最终,self-attention块的中间激活占用显存大小为:
具体怎么计算得来的呢?
- 对于 ,需要保存它们共同的输入 ,这就是中间激活。输入 的形状为,元素个数为 ,占用显存大小为
- 对于 矩阵乘法,需要保存中间激活 ,两个张量的形状都是,占用显存大小合计为
- 对于 函数,需要保存函数的输入 ,占用显存大小为,这里的 表示注意力头数
其中
的形状为:
的形状为:
的形状为:,元素个数为,占用显存大小为
如我司论文100课的一学员“饭饭”所说:每一个token相对于其他token的注意力权重,所以每个token都有N个权重,那么所有token就是N²。 再,每个注意力头,都有这样一套注意力矩阵,所以是N²a,再乘以batch和fp16- 计算完 函数后,会进行dropout操作。需要保存一个mask矩阵,mask矩阵的形状与 相同,占用显存大小为
- 计算在 上的attention,即,需要保存 ,大小为 ;以及 ,大小为 ,二者占用显存大小合计为
- 计算输出映射以及一个dropout操作。输入映射需要保存其输入,大小为;dropout需要保存mask矩阵,大小为,二者占用显存大小合计为
因此,将上述中间激活相加得到,self-attention块的中间激活占用显存大小为
1.2.2 MLP块的中间激活
MLP块的计算公式如下:,最终对于MLP块,需要保存的中间激活值为
具体怎么计算得来的呢?
- 第一个线性层需要保存其输入,占用显存大小为
- 激活函数需要保存其输入,占用显存大小为
- 第二个线性层需要保存其输入,占用显存大小为
- 最后有一个dropout操作,需要保存mask矩阵,占用显存大小为
1.2.3 两个layer norm需要保存的中间激活
另外,self-attention块和MLP块分别对应了一个layer normalization。每个layer norm需要保存其输入,大小为,2个layer norm需要保存的中间激活为
综上,每个transformer层需要保存的中间激活占用显存大小为
对于 层transformer模型,还有embedding层、最后的输出层。embedding层不需要中间激活。总的而言,当隐藏维度 比较大,层数 较深时,这部分的中间激活是很少的,可以忽略
因此,对于 层transformer模型,中间激活占用的显存大小可以近似为 「更多分析见此文《分析transformer模型的参数量、计算量、中间激活、KV cache》」
通过上面两小节的内容,可以看到,transformer模型的计算量和储存复杂度随着序列长度 呈二次方增长。这限制了大语言模型的最大序列长度 的大小
其次,GPT4将最大序列长度 扩大到了32K,Claude更是将最大序列长度 扩大到了100K,这些工作一定采用了一些优化方法来降低原生transformer的复杂度,那具体怎么优化呢?
我们知道,每个transformer层分为两部分:self-attention块和MLP块,但上面计算量中的 项和中间激活中的 项都是self-attention块产生的,与MLP块无关
1.3 标准注意力Standard Attention的两个问题:显存占用多、HBM读写次数多
- 回顾一下,transformer中注意力机制的计算过程为 (再次提醒,如果对transformer相关细节有所遗忘,建议先看此:Transformer通俗笔记,如果忘了什么是softmax,则回顾下此文:如何通俗理解Word2Vec):
其中, ,其中 是序列长度, 是每个注意力头的维度,输出可以记为 - 上面的式子可以拆解为以下三步 在标准注意力实现中, 都要写回到HBM中(下文很快会解释这个HBM),占用了 的内存,通常
例如,对于GPT2, , ;对于GPT3,,
总之,注意力矩阵 需要的内存 远大于 所需要的内存 -
下图展示了标准注意力的实现过程
其中,一共包含八次HBM的矩阵读写操作。这八次读写操作分别为:
第一行对 的读 共两次,对 的写一次,读写操作总共三次
第二行对 读一次,对 写一次,读写操作总共两次
第三行对 的读 共两次,对 的写一次,读写操作总共三次
补充一下背景知识
- 尽管已经有许多近似注意力的方法尝试减少attention的计算和内存要求。例如,稀疏近似和低秩近似的方法,将计算复杂度降低到了序列长度的线性或亚线性
- 但这些近似注意力方法方法并没有得到广泛应用。因为这些方法过于关注FLOPS(浮点数计算次数)的减少,而忽略了IO读写的内存访问开销,导致这并没有效减少运行时间(wall-clock time)
- 总之,在现代GPU中,计算速度已经远超过了显存访问速度,transformer中的大部分计算操作的瓶颈是显存访问。对于显存受限的操作,IO感知是非常重要的,因为显存读写占用了大部分的运行时间
GPU的内存由多个不同大小和不同读写速度的内存组成。内存越小,读写速度越快。对于A100-40GB来说,内存分级图如下所示
- SRAM内存分布在108个流式多处理器上,每个处理器的大小为192K,合计为
相当于计算块,但内存小- 高带宽内存HBM(High Bandwidth Memory),也就是我们常说的显存,大小为40GB。SRAM的读写速度为19TB/s,而HBM的读写速度只有1.5TB/s,不到SRAM的1/10
相当于计算慢,但内存大
总之,transformer的核心组件self-attention块的计算复杂度和空间复杂度是序列长度 的二次方
且对于self-attention块,除了两个大矩阵乘法是计算受限的(、),其他都是内存受限的逐点运算( 例如对 的mask操作、 的softmax操作、对 的dropout操作,这些逐点操作的性能是受限于内存带宽的,会减慢运行时间)
即标准注意力实现存在两个问题:
- 显存占用多,过程中由于实例化了完整的注意力矩阵 ,导致了 的内存要求
- HBM读写次数多,减慢了运行时间(wall- clock time)
接下来2.1节中的Memory-efficient Attention、2.2节中的Flash Attention,便是要分别解决上述这两个问题
第二部分 FlashAttention的前向传递:Memory-efficient Attention/Flash Attention
2.1 Memory-efficient Attention:把显存复杂度从平方降低到线性,但HBM访问次数仍是平方
在注意力计算过程中,节省显存的主要挑战是softmax与的列是耦合的。其方法是单独计算softmax的归一化因子,来实现解耦
- 为了简化分析,忽略计算softmax时“减去最大值”的步骤
记 的第 列为 , 的第 列为 ,有
定义softmax的归一化因子为: - 记 为 的第 个列向量,则输出 的第 个列向量 为:
- 在计算得到归一化因子 后,就可以通过反复累加 来得到
如此,通过节省显存(memory-efficient)的注意力机制,改变了计算顺序,相比于Standard Attention,节省显存的注意力机制将显存复杂度从 降低到了
这种方法在《Online normalizer calculation for softmax》和《Self-attention Does Not Need Memory》中已经使用过,称其为“lazy softmax”,这种方法避免了实例化完整的注意力矩阵 ,从而达到了节省显存的目的。然而HBM访问次数仍然是 的,因此运行时间并没有减少
2.2 Flash Attention:通过kernel融合降低HBM读写次数,避免频繁地从HBM中读写数据
如上文说过的
- 在标准注意力实现中,注意力的性能主要受限于内存带宽,是内存受限的。频繁地从HBM中读写 的矩阵是影响性能的主要瓶颈
- 稀疏近似和低秩近似等近似注意力方法虽然减少了计算量FLOPs,但对于内存受限的操作,运行时间的瓶颈是从HBM中读写数据的耗时,减少计算量并不能有效地减少运行时间(wall-clock time)
- 针对内存受限的标准注意力,Flash Attention是IO感知的,目标是避免频繁地从HBM中读写数据
所以,减少对HBM的读写次数,有效利用更高速的SRAM来进行计算是非常重要的,而对于性能受限于内存带宽的操作,进行加速的常用方式就是kernel融合,该操作的典型方式分为三步:
- 每个kernel将输入数据从低速的HBM中加载到高速的SRAM中
- 在SRAM中,进行计算
- 计算完毕后,将计算结果从SRAM中写入到HBM中
如此,便可避免反复执行“从HBM中读取输入数据,SRAM执行计算,最后将计算结果写入到HBM中”,将多个操作融合成一个操作,减少读写HBM的次数(需要注意的是,模型训练通常会影响到算子融合的效果,因为为了后向传递计算梯度,通常需要将某些中间结果写入到HBM中)
可能有的同学对上面的阐述不甚理解,其实原理很简单,即如下两句话
- 如果把SRAM写回HBM只是为了(重新)加载它来计算softmax
- 那么是可以将其保存在SRAM中,执行所有中间步骤,然后将最终结果写回HBM
前者如下图左侧所示,后者如下图右侧所示(下图图源)
2.2.1 全面阐述分块计算注意力tiling——kernel融合需满足SRAM的内存大小,但无奈SRAM内存太小
虽然通过kernel融合的方式,将多个操作融合为一个操作,利用高速的SRAM进行计算,可以减少读写HBM的次数,从而有效减少内存受限操作的运行时间。但有个问题是
- SRAM的内存大小有限,不可能一次性计算完整的注意力,因此必须进行分块计算,使得分块计算需要的内存不超过SRAM的大小
相当于,内存受限 --> 减少HBM读写次数 --> kernel融合 --> 满足SRAM的内存大小 --> 分块计算,因此分块大小block_size不能太大,否则会导致OOM - 而分块计算的难点是什么呢?
注意力机制的计算过程是“矩阵乘法 --> scale --> mask --> softmax --> dropout --> 矩阵乘法”,矩阵乘法和逐点操作(scale,mask,dropout)的分块计算是容易实现的,难点在于softmax的分块计算。由于计算softmax的归一化因子(分母)时,需要获取到完整的输入数据,进行分块计算的难度比较大
怎么理解上文中的这句“由于计算softmax的归一化因子(分母)时,需要获取到完整的输入数据,进行分块计算的难度比较大”呢?
先回顾一下softmax的计算公式
- 考虑到向量 ,原生softmax的计算过程如下:
- 在实际硬件中,浮点数表示的范围是有限的
对于float32和bfloat16来说,当 时,就会变得很大甚至变成inf,发生数据上溢的问题
故为了避免发生数值溢出的问题,保证数值稳定性,计算时通常会“减去最大值”,称为“safe softmax”
而便被定义为中的最大值
从而,现在所有的深度学习框架中都采用了“safe softmax”这种计算方式- 在训练语言模型时,通常会采用交叉熵损失函数。交叉熵损失函数等价于先执行log_softmax函数,再计算负对数似然函数
且在计算log_softmax时,同样会执行“减去最大值”,这不仅可以避免数值溢出,提高数值稳定性,还可以加快计算速度
总之,要计算输入序列中的特定第个标记对序列中其他标记的关注程度,需要在SRAM中随时可用所有这些分数(这里用表示),但是SRAM的容量是有限的,(序列长度)可以是1000甚至100000个token,会爆炸得很快
总之,tiling的主要思想是分块计算注意力。分块计算的难点在于softmax的分块计算,softmax与 的列是耦合的,通过引入了两个额外的统计量 来进行解耦(前者类似最大分数,后者类似exp分数总和),实现了分块计算
2.2.1.1 通过23个公式全面理解分块计算注意力tiling
我们从头开始,全面梳理下(以下23个公式的阐述修改自此)
-
考虑到向量 ,原生softmax的计算过程如下:
其中,分子对向量 中的第 个元素取指数,分母则是对向量中的所有元素取指数后的和,这确保了softmax 函数的输出是一个概率分布,即所有元素的和为1
便被定义为中的最大值
是一个新的向量,其中每一项相当于在公式4的标准softmax的分子即的每一项的基础上,在其指数项中减去了一个中的最大值
是 softmax 分母中的求和项,为了后面方便描述,下文将公式7中的求和项称为“EXP求和项”
考虑一个大小为2d的向量 ,将其“一切为二”进行分块:
其中
换言之,子向量是原向量 的前半部分,子向量是原向量 的后半部分
假设在分块计算中先处理 ,再处理
那就先使用公式5至公式8对子向量计算它的“局部”,计算过程如下公式9-12所示
很明显,至此得到的并不能算是子向量的最终结果,原因很简单
一者,公式10中的指数项减去的最大值应该是整个向量的最大值,而不应该是子向量的最大值
二者,公式12中分母的EXP求和项应该是关于整个向量的求和项,而非仅仅只是子向量中所有元素的求和项
正因上述计算得到的 不是最终结果,所以将其称为“局部的”
接下来将介绍通过保存额外的一些变量值,在处理完 后更新 的 值的方法
首先,在处理完子向量 后,保存 和 ,相比于保存整个子向量,仅保存这两个标量的开销要小的多
其次,还需要保存两个全局标量: 和
表示当前最大值,因为目前只处理完了 ,所以暂:
表示全局EXP求和项。因为目前只处理完了,所以暂:
接着采用类似处理 的方法来处理,可得如下结果:-
同样道理,此时公式16得到的softmax也是局部而非全局的
但在处理完 之后,可以利用的信息来更新之前保存的两个全局标量 ()和 (),如下公式17和18所示: -
公式17的含义很简单:更新后的全局最大值就是之前的最大值 和 的最大值中更大的那一个 -
公式18是更新的全局EXP求和项的方法
且慢,这是怎么来的呢?不应该是?
以为例, 我们说是“局部的”是因为 到目前为止只用到了的信息, 将 更新至“全局”需要用到把的计算公式15即稍微展开可得:
-
可知导致是“局部”而非“全局”的原因是它减去的max值是“局部的”,所以只需要将这个max值替换为全局的即可
为此可以将 做下变换,以变成全局 - 即
此时的 更新为了:“全局的”
这个公式说明,当需要把某个 更新为“全局的”时,只要将其乘以一个项: ,其中 表示当前对应的最大值, 表示当前最大值
回到公式18,可知其首先用了这种全局更新方法分别将 与更新至全局,然后将它们求和得到当前的EXP求和项
基于上述更新的方法,也能直接更新softmax值
根据公式16即,可知
由于当前的分子和分母都是局部的,所以都需要更新至全局
先看分子部分,由公式14定义即,可将其做下更新 -
即
当对比变换前后,再次印证上面针对公式20所得的结论,即:如想把从局部值变成全局值, 只要将其乘以一个项: ,其中 表示当前对应的最大值, 表示当前最大值
再来看分母部分,我们其实只需要将分母由替换为 即可,这可以由如下公式办到: -
其中的由公式18计算得到
好,问题来了
问题1 网上很多朋友也对此表达过疑惑,即为何公式22这里的分母是而非
答:原因很简单,考虑一下为什么我们使用softmax:它为向量的每一个元素分配一个介于0和1之间的概率值,使得这些概率的总和为1
当我们说"全局",是希望为整个数据集的每一个元素分配概率,而不仅仅是为数据集的一个子集分配
所以当你有一个数据流,分成了两部分:和 。你首先看到 并计算了它的softmax,然后,你看到了 ,为了计算整个数据流(和 合并)的softmax,你不能只单独考虑,你必须考虑 和 合并后的全局效果
接下来 问题 可能又来了,可能马上有同学问
问题2 公式20中的不说是全局的么?
答:公式20中的只是 的全局版本,且它依然只考虑了这个子集下的所有数据,没有考虑整个全部的数据块
问题3 公式20和公式19都只用到了,那它两啥区别
公式19:
这里的最大值是,即的局部最大值。这意味着对于这个数据块,我们将每个元素与其内部的最大值进行比较
公式20:
这里的最大值是 ,它是和 的全局最大值。这意味着我们将的每个元素与所有迄今为止观察到的元素中的最大值进行比较
所以,他们主要的区别是它们使用的参考最大值不同:公式19使用的是局部最大值,而公式20使用的是更全局的最大值。这种变换是为了数值稳定性,确保当我们计算e的指数时不会遇到数值上溢的问题
最后,结合公式21和公式22,的更新可由如下实现: -
仔细看公式23,我们在更新的值时,用到了前面提到的额外保存的几个量:
的局部值,来自公式16
的局部EXP求和项,来自公式15
的局部最大值,来自公式13
全局最大值,来自公式17
全局EXP求和项,来自公式18
同理,可以将上面前三项中的 替换成 来对 的 值进行更新,所有更新过程都不需要用到 或 的向量值
这就是FlashAttention中对值进行动态更新的本质
上述其实是一个增量计算的过程
- 我们首先计算一个分块的局部softmax值,然后存储起来
- 当处理完下一个分块时,可以根据此时的新的全局最大值和全局EXP求和项来更新旧的softmax值,接着再处理下一个分块,然后再更新
- 当处理完所有分块后,此时的所有分块的softmax值都是“全局的”
2.2.1.2 对分块计算注意力tiling的简单总结
可能你的CPU已经干烧了,为缓解烧脑,咱们最后再通过一个简单的例子把上述过程总结一下
对于两个向量 ,解耦拼接向量 的softmax计算:
通过保持两个额外的统计量 ,可以实现softmax的分块计算。需要注意的是,可以利用GPU多线程同时并行计算多个block的softmax。为了充分利用硬件性能,多个block的计算不是串行(sequential)的, 而是并行的
我貌似看到了你脸上隐约有点焦虑的情绪,没事 不急 July懂,单纯的公式毕竟相对晦涩,下面通过一个例子来形象的说明到底是如何分块计算softmax的
对向量 [1,2,3,4] 计算softmax,分成两块 [1,2] 和 [3,4] 进行计算
计算block 1:
计算block 2:
合并得到完整的softmax结果:
2.2.1.3 Flash Attention算法的前向计算算法
在忽略mask和dropout的情况下,简化分析,Flash Attention算法的前向计算过程如下所示
从上图可以看到,该算法在的维度上做外循环,在 的维度上做内循环(而在triton的代码实现中,则采用了在 的维度上做外循环,在 的维度上做内循环)
为本着细致起见,还是针对上述16行代码一行一行解释下,为方便大家理解,再引用知乎上marsggbo画的一个流程图,大家可以对照这个流程图增进对相关代码的理解
首先,有基本条件:
其中,是序列长度, 是每个注意力头的维度,的大小为
Set block sizes ,
计算行/列块大小。为什么ceil() ?因为查询、键和值向量是维的,所以我们还需要将它们组合成输出的维向量。所以这个大小基本上允许我们用q k v和0个向量最大化SRAM的容量以GPT2和A100为例:
A100的SRAM大小为
GPT2中,,对应的的维度为,中间结果的维度为
故
用全0初始化输出矩阵,它将作为一个累加器
类似上文的,其目的是保存softmax的累积分母——exp分数的总和
类似上文的,其逐行保存最大分数,且初始化为-inf,因为我们将对其进行Max运算符,因此无论第一个块的Max是什么,它肯定大于-inf
按照步骤1中的块大小,将, 和分成块
具体来说,则是
沿着行方向分为块,每一分块的大小为
沿着行方向分为块,每一分块的大小为
而
将分割成块
其中,与的块大小相同,也是沿着行方向分为块,每一分块的大小为
至于向量和向量则分为块,每一块子向量大小为
综合上述3、4两个步骤,可以得到各个分块之间的关系如下for 1 ≤ j ≤ Tc do
开始跨列循环(即外部循环,由控制,从上一列到下一列),即跨键/值向量,即遍历,一共循环次Load Kj , Vj from 慢速HBM to on-chip 快速SRAM.
将和块从HBM加载到SRAM(它们的大小为)。在这个时间点上我们仍然有50%的SRAM未被占用(专用于和)
for 1 ≤ i ≤ Tr do
开始跨行内部循环(从上一行到下一行),即跨查询向量,一共循环次,可只在遍历Load Qi , Oi, ℓi, mi from HBM to on-chip SRAM.
将 ()和 ()块以及()和 ()加载到SRAM中这里需要保证和能够载入SRAM(包括所有中间变量)
On chip, compute ,即为
这一步计算 ()和转置()之间的点积,得到分块的Attention Score,在标准的Transformer计算中得到的Attention Score是一个 的矩阵,如下图所示(图中, , )
当,遍历
当,遍历On chip, compute, ,
使用上一步计算的分数计算、和
对分块的Attention Score ,计算它每一行中的最大值基于,计算指数项(归一化-取行最大值并从行分数中减去它,然后EXP):
然后再基于,计算EXP求和项(矩阵的逐行和):
On chip, compute 、
这一步是计算和,举个例子,如下图所说:
包含之前所有块的逐行最大值(j=1 & j=2,用绿色表示),包含当前块的逐行最大值(用黄色表示)。为了得到我们只需要在和之间取一个最大值,也类似
和上文利用公式17即和18即分别更新和,是一个意思Write
为了更好地理解这一行的公式,首先得明白多行一起计算的目的是Batch计算
例如在上上图中,每一个小分块 有多行(图中为3行),但行与行之间的数据不会有任何的交互,只是一种Batch计算的策略。真正的分块意义是在列上,因为softmax是沿着列方向进行的所以为了方便理解,可以想象为 等于1,即每一次只计算上上图中的一个大小为 的分块
基于上述的简化方法,接下来看整个softmax的更新过程。我们用 来表示每一行的Attention Score,用 表示每一行的
因为现在不考虑Batch计算了,所以每一次处理的Attention Score都是一个向量,如上图中的 ,我们首先用公式5至公式8计算它的局部
得到 ,此时中只有前两个位置有值,对应的是的局部 值然后用相同的方法处理它下方的每一行(绿色部分的前两列)
接着处理 ,同理首先用公式5至公式8计算它的局部,然后用公式23即对 进行更新(注意,通过上面第11行,可知即等同于):
(记为公式24)
其中 等价于公式6即的结果
当处理到 时,继续套用公式24来更新即可:
(记为公式25)
下面再进一步,直接尝试来更新输出 ,而不仅仅是值。方法其实很简单,只要在每次动态更新完 ,乘上其对应的 的值即可:
(记为公式26)
其中 对应的是 中的列数(2)
拿着公式26与上面的伪代码进行对比,可知伪代码中的公式仅仅是公式26的矩阵版本。到此,可以看到用公式26即可实现分块的Self-Attention计算
Write
更新和end for
end for
Return O.
2.2.2 重计算
上文讲到,模型训练会影响kernel融合的效果,为了后向传递计算梯度,前向计算时通常需要将某些中间结果写回到HBM中,这会产生额外的HBM读写次数,减慢运行时间。因此,Flash Attention没有为后向传递保存很大的中间结果矩阵
在标准注意力实现中,后向传递计算 的梯度时,需要用到 的中间矩阵 ,但这两个矩阵并没有保存下来。这里的技巧是重计算,保存了两个统计量,后向传递时在高速的SRAM上快速地重新计算Attention,通过分块的方式重新计算注意力矩阵。相比于标准注意力中,从HBM中读取很大的中间注意力矩阵的方法,重计算的方法要快得多。
总的来说,Flash Attention通过调整注意力的计算顺序,引入两个额外的统计量进行分块计算,避免了实例化完整的 的注意力矩阵,将显存复杂度从 降低到了 。另外,对于内存受限的标准注意力,Flash Attention通过kernel融合和分块计算,大量减少了HBM访问次数,尽管由于后向传递中的重计算增加了额外的计算量FLOPs,减少了运行时间,计算速度更快(GPT2的7.6)
2.2.3 kernel融合
为了简化分析,上文介绍注意力时忽略了mask和dropout操作。下面详细介绍Flash Attention前向传递的细节。给定输入,计算得到注意力输出
其中, 是softmax的缩放因子,典型的比如 。MASK操作将输入中的某些元素置为 −∞ ,计算softmax后就变成了0,其他元素保持不变
causal-lm结构和prefix-lm结构的主要差别就是MASK矩阵不同。逐点作用在 的每个元素上,以 的概率将该元素置为0,以 的概率将元素置为
tiling分块计算使得我们可以用一个CUDA kernel来执行注意力的所有操作。从HBM中加载输入数据,在SRAM中执行所有的计算操作(矩阵乘法、mask、softmax、dropout、矩阵乘法),再将计算结果写回到HBM中。通过kernel融合将多个操作融合为一个操作,避免了反复地从HBM中读写数据
kernel融合如下图所示,图片来源于https://www.bilibili.com/video/BV1Zz4y1q7FX/
考虑mask和dropout操作,完整Flash Attention算法的前向计算过程如下所示:
// 待更..
第三部分 FlashAttention2
// 待更
参考文献与推荐阅读
- Transformer通俗笔记:从Word2Vec、Seq2Seq逐步理解到GPT、BERT
- 分析transformer模型的参数量、计算量、中间激活、KV cache
- FlashAttention:加速计算,节省显存, IO感知的精确注意力
- FlashAttention 的速度优化原理是怎样的?,其中Civ、marsggbo回答的均不错
- FlashAttention图解(如何加速Attention)、FlashAttention算法详解
- 图解大模型计算加速系列:FlashAttention V1,从硬件到计算逻辑
创作与修订记录
- 10.6,在《ChatGLM两代的部署/微调/实现》一文中阐述「FlashAttention的原理与结构:减少内存访问提升计算速度」时,感觉会越写越长,故把FlashAttention相关的内容放到本新一篇博客里
- 10.7,主要修订第一部分
- 10.8,主要修订第二部分的2.2节
- 10.9,反复修订2.2节,以最大程度的提高可读性
反复修订2.2.1.1节:通过23个公式全面理解分块计算注意力tiling
反复修订2.2.1.3节:Flash Attention算法的前向计算算法 - 12.27,在“1.2.1 Self-Attention块的中间激活”节中新增一个说明解释
- 12.28,在“Flash Attention算法的前向计算算法”节中补充一个对理解该算法外循环、内循环很重要的两张图
并把本文的标题改成最新的:《通透理解FlashAttention与FlashAttention2:全面降低显存读写、加快计算速度》
这篇关于通透理解FlashAttention与FlashAttention2:全面降低显存读写、加快计算速度的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!