本文主要是介绍FlashAttention-2 论文阅读笔记,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
FlashAttention-2是对原始FlashAttention算法的一系列改进,旨在优化在GPU上的计算性能。本节详细讨论了FlashAttention-2的算法、并行性以及工作分区策略。
算法
FlashAttention-2的关键优化点在于减少非矩阵乘法(matmul)的浮点运算,以充分利用GPU上的专用计算单元(如Nvidia GPU上的Tensor Cores),这些单元在处理matmul操作(尤其是在FP16/BF16格式下)时性能显著优化。该优化的目标是通过尽可能多地执行matmul操作来最大化GPU的吞吐量。
前向传播
-
在线Softmax技巧:FlashAttention-2对在线Softmax计算进行了修改,以最小化非matmul浮点操作:
- 避免通过
diag(ℓ(2))^-1
重新缩放输出更新的两个项。 - 维持一个“未缩放”的O(2)版本,并保留统计信息 ℓ(2)。
- 仅在循环结束时,通过
diag(ℓ(last))^-1
缩放最终的O(last)以获得正确的输出。
- 避免通过
-
最大化matmul FLOPs:为了最大化GPU的性能,FlashAttention-2重点优化了matmul操作,因为现代GPU上的专用单元(如Tensor Cores)在这些操作上表现出色。以Nvidia A100 GPU为例,其FP16/BF16 matmul的理论吞吐量可以达到312 TFLOPs/s,而非matmul FP32的吞吐量仅为19.5 TFLOPs/s。因此,FlashAttention-2通过优化算法,尽可能地减少非matmul操作,从而保持高吞吐量的执行效率。
-
算法细节:FlashAttention-2的前向传播通过以下步骤实现:
- 将输入矩阵Q、K、V分成大小为𝐵𝑟 × 𝑑的𝑇𝑟块,将输出矩阵O和logsumexp𝐿也相应地分块。
- 在每个线程块内部分配工作以最大化GPU资源的利用。
- 引入了在线Softmax技巧,通过有效管理和缩放中间结果,减少了不必要的计算开销。
反向传播
FlashAttention-2的反向传播与FlashAttention类似,但也有一些微调:
- 仅使用逐行logsumexp 𝐿,而不是softmax中的最大值和指数和。
- 使用类似的分块策略来优化计算和内存访问,以提高反向传播的效率和性能。
FlashAttention-2在并行性和工作分区方面进行了深入优化,以在GPU上实现更高的计算效率和性能。本节详细讨论了FlashAttention-2的并行化策略和工作分区方法。
并行性
前向传播
在FlashAttention-2中,前向传播的并行化策略如下:
-
线程块调度:每个注意力头使用一个线程块来处理,总共有batch size × number of heads个线程块。每个线程块被调度到一个流多处理器(SM)上执行。例如,Nvidia A100 GPU上有108个这样的SM。这种调度在大量线程块(如≥ 80)时非常高效,因为可以充分利用GPU的计算资源。
-
对长序列的优化:对于长序列(通常意味着较小的batch size或较少的头数),为了更好地利用GPU上的多处理器,FlashAttention-2额外并行化了序列长度维度。这在这种情况下显著提高了性能和效率。
反向传播
在反向传播中,为了避免在不同列块之间的共享计算,FlashAttention-2采用了类似的并行化策略:
- 线程块调度:每个列块使用一个线程块来处理。通过使用原子加操作来在不同线程块之间进行通信,以更新dQ,从而避免了共享内存的读写冲突。
工作分区
前向传播
在前向传播中,FlashAttention-2改进了工作分区策略,避免了FlashAttention中的"split-K"方案,具体包括:
- K和V的分割:FlashAttention-2将Q分割到4个线程束(warp)中,同时使得K和V对所有线程束可访问。每个线程束执行矩阵乘法以获取QK>的一部分,并将其与V的一部分相乘,从而获得对应输出的片段。这种改进减少了线程束之间的通信,降低了共享内存的读写次数,从而提升了性能。
反向传播
在反向传播中,为了避免"split-K"方案带来的同步问题,FlashAttention-2选择了适当的线程束分区策略,以优化计算和内存访问效率。
这篇关于FlashAttention-2 论文阅读笔记的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!