LLM - GPT(Decoder Only) 类模型的 KV Cache 公式与原理 教程

2024-08-28 12:12

本文主要是介绍LLM - GPT(Decoder Only) 类模型的 KV Cache 公式与原理 教程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/141605718

免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。


Img

在 GPT 类模型中,KV Cache (键值缓存) 是用于优化推理效率的重要技术,基本思想是通过缓存先前计算的 键(Key) 和 值(Value),避免在推理过程中,重复计算 Mask 的 注意力(Attention) 矩阵,从而加速生成过程。

1. 公式

矩阵乘法的基础性质:

A ⋅ B = [ A 1 A 2 … A n ] ⋅ [ B 1 B 2 ⋮ B n ] = A 1 B 1 + A 2 B 2 + ⋯ + A n B n A \cdot B = \begin{bmatrix} A_{1} & A_{2} & \dots & A_{n} \end{bmatrix} \cdot \begin{bmatrix} B_{1} \\ B_{2} \\ \vdots \\ B_{n} \end{bmatrix} = A_{1}B_{1} + A_{2}B_{2} + \dots + A_{n}B_{n} AB=[A1A2An] B1B2Bn =A1B1+A2B2++AnBn

其中 A i A_{i} Ai A A A 的列向量, B i B_{i} Bi B B B 的行向量,也就是说相同维度的向量相乘,可拆解成行向量乘以列向量,即 A A A n n n 列, B B B n n n 行。如图:

matrix

例如:基础的矩阵乘法:

A = [ 1 2 3 4 ] , B = [ 5 6 7 8 ] C = [ 1 ∗ 5 + 2 ∗ 7 1 ∗ 6 + 2 ∗ 8 3 ∗ 5 + 4 ∗ 7 3 ∗ 6 + 4 ∗ 8 ] = [ 19 22 43 50 ] A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}, \quad B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \\ C = \begin{bmatrix} 1*5 + 2*7 & 1*6 + 2*8 \\ 3*5 + 4*7 & 3*6 + 4*8 \end{bmatrix} = \begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix} A=[1324],B=[5768]C=[15+2735+4716+2836+48]=[19432250]

也可以写成,行列向量相乘的形式,即 A 拆分出多个行向量,B 拆分出多个列向量,即:

C = [ 1 3 ] ⋅ [ 5 6 ] + [ 2 4 ] ⋅ [ 7 8 ] = [ 1 ∗ 5 1 ∗ 6 3 ∗ 5 3 ∗ 6 ] + [ 2 ∗ 7 2 ∗ 8 4 ∗ 7 4 ∗ 8 ] C = \begin{bmatrix} 1 \\ 3 \end{bmatrix} \cdot \begin{bmatrix} 5 & 6 \end{bmatrix} + \begin{bmatrix} 2 \\ 4 \end{bmatrix} \cdot \begin{bmatrix} 7 & 8 \end{bmatrix}= \begin{bmatrix} 1*5 & 1*6 \\ 3*5 & 3*6 \end{bmatrix} + \begin{bmatrix} 2*7 & 2*8 \\ 4*7 & 4*8 \end{bmatrix} C=[13][56]+[24][78]=[15351636]+[27472848]
= [ 1 ∗ 5 + 2 ∗ 7 1 ∗ 6 + 2 ∗ 8 3 ∗ 5 + 4 ∗ 7 3 ∗ 6 + 4 ∗ 8 ] = [ 19 22 43 50 ] =\begin{bmatrix} 1*5 + 2*7 & 1*6 + 2*8 \\ 3*5 + 4*7 & 3*6 + 4*8 \end{bmatrix} = \begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix} =[15+2735+4716+2836+48]=[19432250]

进一步拆解:

A ⋅ B = A 1 B 1 + A 2 B 2 + ⋯ + A n B n = [ a 1 , 1 B 1 a 2 , 1 B 2 ⋮ a m , 1 B n ] + [ a 1 , 2 B 1 a 2 , 2 B 2 ⋮ a m , 2 B n ] + ⋯ + [ a 1 , n B 1 a 2 , n B 2 ⋮ a m , n B n ] = [ a 1 , 1 B 1 + a 1 , 2 B 1 + ⋯ + a 1 , n B 1 a 2 , 1 B 2 + a 2 , 2 B 2 + ⋯ + a 2 , n B 2 ⋯ a m , 1 B n + a m , 2 B n + ⋯ + a m , n B n ] A \cdot B = A_{1}B_{1} + A_{2}B_{2} + \dots + A_{n}B_{n} \\ = \begin{bmatrix} a_{1,1}B_{1} \\ a_{2,1}B_{2} \\ \vdots \\ a_{m,1}B_{n} \end{bmatrix} + \begin{bmatrix} a_{1,2}B_{1} \\ a_{2,2}B_{2} \\ \vdots \\ a_{m,2}B_{n} \end{bmatrix} + \cdots + \begin{bmatrix} a_{1,n}B_{1} \\ a_{2,n}B_{2} \\ \vdots \\ a_{m,n}B_{n} \end{bmatrix} \\ = \begin{bmatrix} a_{1,1}B_{1} + a_{1,2}B_{1} + \cdots + a_{1,n}B_{1} \\ a_{2,1}B_{2} + a_{2,2}B_{2} + \cdots + a_{2,n}B_{2} \\ \cdots \\ a_{m,1}B_{n} + a_{m,2}B_{n} + \cdots + a_{m,n}B_{n} \end{bmatrix} AB=A1B1+A2B2++AnBn= a1,1B1a2,1B2am,1Bn + a1,2B1a2,2B2am,2Bn ++ a1,nB1a2,nB2am,nBn = a1,1B1+a1,2B1++a1,nB1a2,1B2+a2,2B2++a2,nB2am,1Bn+am,2Bn++am,nBn

基础的矩阵乘法的另一种形式:

C = [ 1 3 ] ⋅ [ 5 , 6 ] + [ 2 4 ] ⋅ [ 7 , 8 ] C=\begin{bmatrix} 1 \\ 3 \end{bmatrix} \cdot \begin{bmatrix} 5,6 \end{bmatrix} + \begin{bmatrix} 2 \\ 4 \end{bmatrix} \cdot \begin{bmatrix} 7,8 \end{bmatrix} C=[13][5,6]+[24][7,8]
[ 1 ∗ [ 5 6 ] 3 ∗ [ 5 6 ] ] + [ 2 ∗ [ 7 8 ] 4 ∗ [ 7 8 ] ] \begin{bmatrix} 1*[5&6] \\ 3*[5&6] \end{bmatrix} + \begin{bmatrix} 2*[7&8] \\ 4*[7&8] \end{bmatrix} [1[53[56]6]]+[2[74[78]8]]
[ 1 ∗ 5 1 ∗ 6 3 ∗ 5 3 ∗ 6 ] + [ 2 ∗ 7 2 ∗ 8 4 ∗ 7 4 ∗ 8 ] = [ 19 22 43 50 ] \begin{bmatrix} 1*5 & 1*6 \\ 3*5 & 3*6 \end{bmatrix} + \begin{bmatrix} 2*7 & 2*8 \\ 4*7 & 4*8 \end{bmatrix} = \begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix} [15351636]+[27472848]=[19432250]

如果 A A A 是下三角矩阵,即包含 Mask 信息,Decoder 无法观察到之后的推理部分,则 A ⋅ B A \cdot B AB,输出:

A ⋅ B = [ a 1 , 1 B 1 a 2 , 1 B 2 + a 2 , 2 B 2 ⋯ a m , 1 B n + a m , 2 B n + ⋯ + a m , n B n ] A \cdot B = \left[ \begin{array}{llll} a_{1,1}B_{1}\\ a_{2,1}B_{2} + a_{2,2}B_{2}\\ \cdots \\ a_{m,1}B_{n} + a_{m,2}B_{n} + \cdots + a_{m,n}B_{n} \end{array} \right] AB= a1,1B1a2,1B2+a2,2B2am,1Bn+am,2Bn++am,nBn

2. 推理

第1步:

在 Decoder 解码过程中,只关注 Transformer 的 自注意力(Self-Attention),输入第 1 个 Token,将 Token 转换成 输入特征 I n p u t 1 = [ 1 , d e m b ] Input_{1}=[1,d_{emb}] Input1=[1,demb],暂时忽略 batch_size d e m b d_{emb} demb 表示 Embedding Size。

  1. 输入特征 I n p u t 0 = [ 1 , d e m b ] Input_{0}=[1,d_{emb}] Input0=[1,demb],乘以权重 W = [ d e m b , 3 ∗ d e m b ] W=[d_{emb}, 3*d_{emb}] W=[demb,3demb] (已训练完成,值是固定的),输出维度 [ 1 , 3 ∗ d e m b ] [1, 3*d_{emb}] [1,3demb],即作为 Q\K\V,每个向量 [ 1 , d e m b ] [1,d_{emb}] [1,demb]

    • Q 1 = [ 1 , d e m b ] Q_{1}=[1,d_{emb}] Q1=[1,demb] K 1 = [ 1 , d e m b ] K_{1}=[1,d_{emb}] K1=[1,demb] V 1 = [ 1 , d e m d ] V_{1}=[1,d_{emd}] V1=[1,demd],只与输入特征 I n p u t 0 Input_{0} Input0 的 Embedding 相关。
  2. 根据 Self-Attention 的公式,忽略 d \sqrt{d} d ,只有1维,mask 不起作用,即
    A t t ( Q , K , V ) = s o f t m a x ( Q K ⊤ + m a s k ) ∗ V A t t 1 ( Q , K , V ) = s o f t m a x ( Q 1 K 1 ⊤ ) V 1 其中  s o f t m a x ( x i ) = e x i ∑ j = 1 n e x j Att(Q,K,V)=softmax(QK^{\top}+mask)*V \\ Att_{1}(Q,K,V)=softmax(Q_{1}K_{1}^{\top})V_{1} \\ 其中 \ softmax(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} Att(Q,K,V)=softmax(QK+mask)VAtt1(Q,K,V)=softmax(Q1K1)V1其中 softmax(xi)=j=1nexjexi

  3. A t t 0 Att_{0} Att0 ( [ 1 , d e m b ] [1,d_{emb}] [1,demb]) 经过一系列推理,最后输出 [ 1 , d v ] [1, d_{v}] [1,dv] d v d_{v} dv 是全部词元 Token 的数量,根据概率值即可获得最后的 Token。

第 2 步

将第 1 步输出的 Token 转换成 [ 1 , d e m b ] [1,d_{emb}] [1,demb],与第 1 步组合至一起,即 输入特征 I n p u t 1 = [ 2 , d e m b ] Input_{1}=[2,d_{emb}] Input1=[2,demb]

  1. 输入特征 I n p u t 1 = [ 2 , d e m b ] Input_{1}=[2,d_{emb}] Input1=[2,demb],乘以权重 W = [ d e m b , 3 ∗ d e m b ] W=[d_{emb}, 3*d_{emb}] W=[demb,3demb],权重是固定的,因此只需要计算第 2 个输入的特征 [ 1 , d e m b ] [1,d_{emb}] [1,demb],第 1 个不需要计算,也就是说 Q\K\V 的维度是 [ 2 , d e m b ] [2, d_{emb}] [2,demb],只需计算一次即可,剩余的可以直接 c o n c a t concat concat 到一起。

  2. 根据 Self-Attention 的公式,忽略 d \sqrt{d} d ,注意第1行,已经计算,第2行,需要使用 Q 2 Q_{2} Q2 K 2 K_{2} K2 V 2 V_{2} V2,进行计算,即:
    A t t 2 ( Q , K , V ) = s o f t m a x ( Q K ⊤ + m a s k ) ∗ V s o f t m a x ( [ Q 1 K 1 ⊤ Q 2 K 1 ⊤ + Q 2 K 2 ⊤ ] ) ⋅ [ V 1 V 2 ] = [ s o f t m a x ( Q 1 K 1 ⊤ ) V 1 s o f t m a x ( Q 2 K 1 ⊤ ) V 1 + s o f t m a x ( Q 2 K 2 ⊤ ) V 2 ] = [ A t t 1 ( Q , K , V ) s o f t m a x ( Q 2 K 1 ⊤ ) V 1 + s o f t m a x ( Q 2 K 2 ⊤ ) V 2 ] Att_{2}(Q,K,V) = softmax(QK^{\top}+mask)*V \\ softmax(\left[ \begin{array}{ll} Q_{1}K_{1}^{\top}\\ Q_{2}K_{1}^{\top} + Q_{2}K_{2}^{\top}\\ \end{array} \right]) \cdot \begin{bmatrix} V_{1} \\ V_{2} \\ \end{bmatrix} \\= \left[ \begin{array}{ll} softmax(Q_{1}K_{1}^{\top})V_{1}\\ softmax(Q_{2}K_{1}^{\top})V_{1} + softmax(Q_{2}K_{2}^{\top})V_{2}\\ \end{array} \right] \\ = \left[ \begin{array}{} Att_{1}(Q,K,V) \\ softmax(Q_{2}K_{1}^{\top})V_{1} + softmax(Q_{2}K_{2}^{\top})V_{2}\\ \end{array} \right] Att2(Q,K,V)=softmax(QK+mask)Vsoftmax([Q1K1Q2K1+Q2K2])[V1V2]=[softmax(Q1K1)V1softmax(Q2K1)V1+softmax(Q2K2)V2]=[Att1(Q,K,V)softmax(Q2K1)V1+softmax(Q2K2)V2]

  3. KV 都是成对出现的,如果 缓存 KV,则可以加快推理速度。

第 3 步:重复进行。

3. 缓存占用

关于 Llama3 的 KV Cache 源码,参考 model.py:

xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xvkeys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]

关于 KV 的缓存内存占用:

相关参数 batch_size=32head=32layer=32dim_size=4096seq_length=2048,float32(4个字节)类,计算 KV cache 的缓存占用:
M = 2 ∗ N b s ∗ ( N d i m / N h e a d ∗ N h e a d ) ∗ N l a y e r ∗ N s e q ∗ 4 = 2 ∗ 32 ∗ 4096 ∗ 32 ∗ 2048 ∗ 4 / 1024 / 1024 / 1024 = 64 G M=2*N_{bs}*(N_{dim}/N_{head}*N_{head})*N_{layer}*N_{seq}*4 \\ =2*32*4096*32*2048*4/1024/1024/1024=64G M=2Nbs(Ndim/NheadNhead)NlayerNseq4=23240963220484/1024/1024/1024=64G
也就是说 head 数量无关,因为维度除以 Head 再乘以 Head。Llama3 使用 GQA (Grouped Query Attention) 分组查询注意力机制,降低 4 倍的 KV Cache,head=32,kv_head=8,即 scale=head/kv_head=4

参考:

  • CSDN - 从头开始实现 LLaMA3 的网络结构与推理流程 教程
  • Transformers KV Caching Explained

这篇关于LLM - GPT(Decoder Only) 类模型的 KV Cache 公式与原理 教程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1114735

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

Makefile简明使用教程

文章目录 规则makefile文件的基本语法:加在命令前的特殊符号:.PHONY伪目标: Makefilev1 直观写法v2 加上中间过程v3 伪目标v4 变量 make 选项-f-n-C Make 是一种流行的构建工具,常用于将源代码转换成可执行文件或者其他形式的输出文件(如库文件、文档等)。Make 可以自动化地执行编译、链接等一系列操作。 规则 makefile文件

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

hdu4407(容斥原理)

题意:给一串数字1,2,......n,两个操作:1、修改第k个数字,2、查询区间[l,r]中与n互质的数之和。 解题思路:咱一看,像线段树,但是如果用线段树做,那么每个区间一定要记录所有的素因子,这样会超内存。然后我就做不来了。后来看了题解,原来是用容斥原理来做的。还记得这道题目吗?求区间[1,r]中与p互质的数的个数,如果不会的话就先去做那题吧。现在这题是求区间[l,r]中与n互质的数的和

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

4B参数秒杀GPT-3.5:MiniCPM 3.0惊艳登场!

​ 面壁智能 在 AI 的世界里,总有那么几个时刻让人惊叹不已。面壁智能推出的 MiniCPM 3.0,这个仅有4B参数的"小钢炮",正在以惊人的实力挑战着 GPT-3.5 这个曾经的AI巨人。 MiniCPM 3.0 MiniCPM 3.0 MiniCPM 3.0 目前的主要功能有: 长上下文功能:原生支持 32k 上下文长度,性能完美。我们引入了

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}