本文主要是介绍长上下文训练的关键因素(2)-flash-attention,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
上一篇在这 长上下文训练的关键因素(1) (qq.com)
我看有读者留言说想看Flash-attention,那么今天就讲它
这东西算法其实全讲老复杂了,我们挑着好理解方式讲
我们上节课讲到了计算复杂度的关系,就是以下这部分,先来复习一下
QKV都是由self-attention,也就是由输入数据self出来的,输入的数据[B,n,D]的最后一维和Wq,Wk,Wv三个矩阵是相等的。
Wq矩阵=[D,D],Wk和Wv也都一样, 然后输入数据分别和Wq,Wk,Wv点乘出来QKV3个值。
拿Q举例,输入数据[B,n,D]要和q矩阵[D,D]点积,生成出来的东西是[B,n,D],K和V也一样。
然后按照公式来计算
Q*K的转置,就是[B,n,D]*[B,n,D],就是[B,n,n]
如果考虑到多头注意力因素就是[B,header_number,n,head_dim]*
[B,header_number,n,head_dim]=[B,header_number,n,n],都差不多
不管是不是多头,点乘的计算量都是B*n^2*D,D固定的,B是batch_size,也不用考虑,所以计算量相关性最大的变数就剩下n^2了,n越大,肯定平方值越大,复杂度越高,这个就是算力复杂度和n的平方相关的由来。
至于后面QK的softmax结果和V相乘也是一样的,但是因为他俩是2次的顺序计算,所以还是和n的平方相关,也是transformer最被诟病的地方。
然后上一篇结尾我也说了,Flash-attention和别的实现厂商长下文的,实现思路不一样,那它为啥不一样呢?
首先说它干了什么事
1. 离GPU的计算资源最近
2. 摒弃了反复在HBM和SRAM,GPU之前反复导数据
3. 省内存
4. 算力呢?算力到没省...
我们一个一个看。<
这篇关于长上下文训练的关键因素(2)-flash-attention的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!