LLama2源码分析——Rotary Position Embedding分析

2024-06-07 00:36

本文主要是介绍LLama2源码分析——Rotary Position Embedding分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考:一文看懂 LLaMA 中的旋转式位置编码(Rotary Position Embedding)

原理推导参考自上文,以下结合huggingface代码分析公式计算过程

1 旋转角度计算

计算公式如下,其中d为词嵌入维度,这部分和论文原文一样
θ j = 1000 0 − 2 ( j − 1 ) / d , j ∈ [ 1 , 2 , … , d / 2 ] \theta_j=10000^{-2(j-1)/d},j\in [1,2,\ldots,d/2] θj=100002(j1)/d,j[1,2,,d/2]

# 计算词向量元素两两分组之后,每组元素对应的旋转角度
# 维度:[dim / 2]
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))

2 计算整个seq的cos_sin矩阵

def _set_cos_sin_cache(self, seq_len, device, dtype):self.max_seq_len_cached = seq_len# 生成token长度序列t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)# 计算两个矩阵的外积,结果维度[seq_len, dim // 2]freqs = torch.einsum("i,j->ij", t, self.inv_freq)# 类似[[0, 2, 4, ..., 0, 2, 4, ...], ...]形式,旋转角度两两一组相同emb = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

3 计算旋转式位置编码

f q ( x m , m ) = ( W q x m ) e i m θ f k ( x n , n ) = ( W k x n ) e i n θ \begin{aligned}f_q(x_m,m)&=(W_qx_m)e^{im\theta} \\f_k(x_n,n)&=(W_kx_n)e^{in\theta}\end{aligned} fq(xm,m)fk(xn,n)=(Wqxm)eimθ=(Wkxn)einθ
公式根据欧拉公式转化后为
( q m ( 1 ) + i q m ( 2 ) ) ∗ ( cos ⁡ ( m θ ) + i sin ⁡ ( m θ ) ) (q_{m}^{(1)}+iq_{m}^{(2)})*(\cos(m\theta)+i\sin(m\theta)) (qm(1)+iqm(2))(cos(mθ)+isin(mθ))

展开后将结果重新表示为实数向量即为
q m e i m θ = [ q m ( 1 ) cos ⁡ ( m θ ) − q m ( 2 ) sin ⁡ ( m θ ) , q m ( 2 ) cos ⁡ ( m θ ) + q m ( 1 ) sin ⁡ ( m θ ) ] q_me^{im\theta}=[q_m^{(1)}\cos(m\theta)-q_m^{(2)}\sin(m\theta),q_m^{(2)}\cos(m\theta)+q_m^{(1)}\sin(m\theta)] qmeimθ=[qm(1)cos(mθ)qm(2)sin(mθ),qm(2)cos(mθ)+qm(1)sin(mθ)]
key的计算同理,以上公式是2维embedding的旋转编码计算,实际代码中是将高纬度的embedding两两分组按照上述公式计算,同一组内的旋转角度相同,此处Llama代码中的分组计算方式与论文原文有所区别,论文原文中是将embedding_dim维度(最后一维)的向量按照相邻两个位置数字为一组,可以按照如下代码理解

>>> a
tensor([[1, 2, 3, 4, 5, 6, 7, 8],[1, 2, 3, 4, 5, 6, 7, 8]])
>>> a.view(2, -1, 2)
tensor([[[1, 2],[3, 4],[5, 6],[7, 8]],[[1, 2],[3, 4],[5, 6],[7, 8]]])

因此,单个token的位置编码是如下图方式计算
image
但以上的R矩阵比较稀疏,计算时浪费大量算力,因此Llama中采用不同的方式计算

  • Llama源码中计算方法

( q 0 q 1 ⋮ q d / 2 − 1 q d / 2 q d / 2 + 1 ⋮ q d − 1 ) ⊗ ( cos ⁡ m θ 0 cos ⁡ m θ 2 cos ⁡ m θ 4 ⋮ cos ⁡ m θ d − 2 cos ⁡ m θ 0 cos ⁡ m θ 2 ⋮ cos ⁡ m θ d − 2 ) + ( − q d / 2 − q d / 2 + 1 ⋮ − q d − 1 q 1 q 2 ⋮ q d / 2 − 1 ) ⊗ ( sin ⁡ m θ 0 sin ⁡ m θ 2 sin ⁡ m θ 4 ⋮ sin ⁡ m θ d − 2 sin ⁡ m θ 0 sin ⁡ m θ 2 ⋮ sin ⁡ m θ d − 2 ) \begin{pmatrix} {q_0}\\{q_1}\\{\vdots}\\{q_{d/2-1}}\\{q_{d/2}}\\{q_{d/2+1}}\\{\vdots}\\{q_{d-1}} \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_0\\\cos m\theta_2\\\cos m\theta_4\\\vdots\\\cos m\theta_{d-2}\\\cos m\theta_0\\\cos m\theta_2\\\vdots\\\cos m\theta_{d-2} \end{pmatrix} + \begin{pmatrix} {-q_{d/2}}\\{-q_{d/2+1}}\\\vdots\\{-q_{d-1}}\\{q_{1}}\\{q_{2}}\\\vdots\\{q_{d/2-1}} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_0\\\sin m\theta_2\\\sin m\theta_4\\\vdots\\\sin m\theta_{d-2}\\\sin m\theta_0\\\sin m\theta_2\\\vdots\\\sin m\theta_{d-2} \end{pmatrix} q0q1qd/21qd/2qd/2+1qd1 cosmθ0cosmθ2cosmθ4cosmθd2cosmθ0cosmθ2cosmθd2 + qd/2qd/2+1qd1q1q2qd/21 sinmθ0sinmθ2sinmθ4sinmθd2sinmθ0sinmθ2sinmθd2

def rotate_half(x):"""Rotates half the hidden dims of the input."""x1 = x[..., : x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2 :]return torch.cat((-x2, x1), dim=-1)def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):cos = cos[position_ids].unsqueeze(unsqueeze_dim)sin = sin[position_ids].unsqueeze(unsqueeze_dim)q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed

这篇关于LLama2源码分析——Rotary Position Embedding分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1037685

相关文章

Springboot中分析SQL性能的两种方式详解

《Springboot中分析SQL性能的两种方式详解》文章介绍了SQL性能分析的两种方式:MyBatis-Plus性能分析插件和p6spy框架,MyBatis-Plus插件配置简单,适用于开发和测试环... 目录SQL性能分析的两种方式:功能介绍实现方式:实现步骤:SQL性能分析的两种方式:功能介绍记录

最长公共子序列问题的深度分析与Java实现方式

《最长公共子序列问题的深度分析与Java实现方式》本文详细介绍了最长公共子序列(LCS)问题,包括其概念、暴力解法、动态规划解法,并提供了Java代码实现,暴力解法虽然简单,但在大数据处理中效率较低,... 目录最长公共子序列问题概述问题理解与示例分析暴力解法思路与示例代码动态规划解法DP 表的构建与意义动

C#使用DeepSeek API实现自然语言处理,文本分类和情感分析

《C#使用DeepSeekAPI实现自然语言处理,文本分类和情感分析》在C#中使用DeepSeekAPI可以实现多种功能,例如自然语言处理、文本分类、情感分析等,本文主要为大家介绍了具体实现步骤,... 目录准备工作文本生成文本分类问答系统代码生成翻译功能文本摘要文本校对图像描述生成总结在C#中使用Deep

Go中sync.Once源码的深度讲解

《Go中sync.Once源码的深度讲解》sync.Once是Go语言标准库中的一个同步原语,用于确保某个操作只执行一次,本文将从源码出发为大家详细介绍一下sync.Once的具体使用,x希望对大家有... 目录概念简单示例源码解读总结概念sync.Once是Go语言标准库中的一个同步原语,用于确保某个操

Redis主从/哨兵机制原理分析

《Redis主从/哨兵机制原理分析》本文介绍了Redis的主从复制和哨兵机制,主从复制实现了数据的热备份和负载均衡,而哨兵机制可以监控Redis集群,实现自动故障转移,哨兵机制通过监控、下线、选举和故... 目录一、主从复制1.1 什么是主从复制1.2 主从复制的作用1.3 主从复制原理1.3.1 全量复制

Redis主从复制的原理分析

《Redis主从复制的原理分析》Redis主从复制通过将数据镜像到多个从节点,实现高可用性和扩展性,主从复制包括初次全量同步和增量同步两个阶段,为优化复制性能,可以采用AOF持久化、调整复制超时时间、... 目录Redis主从复制的原理主从复制概述配置主从复制数据同步过程复制一致性与延迟故障转移机制监控与维

Redis连接失败:客户端IP不在白名单中的问题分析与解决方案

《Redis连接失败:客户端IP不在白名单中的问题分析与解决方案》在现代分布式系统中,Redis作为一种高性能的内存数据库,被广泛应用于缓存、消息队列、会话存储等场景,然而,在实际使用过程中,我们可能... 目录一、问题背景二、错误分析1. 错误信息解读2. 根本原因三、解决方案1. 将客户端IP添加到Re

Java汇编源码如何查看环境搭建

《Java汇编源码如何查看环境搭建》:本文主要介绍如何在IntelliJIDEA开发环境中搭建字节码和汇编环境,以便更好地进行代码调优和JVM学习,首先,介绍了如何配置IntelliJIDEA以方... 目录一、简介二、在IDEA开发环境中搭建汇编环境2.1 在IDEA中搭建字节码查看环境2.1.1 搭建步

Redis主从复制实现原理分析

《Redis主从复制实现原理分析》Redis主从复制通过Sync和CommandPropagate阶段实现数据同步,2.8版本后引入Psync指令,根据复制偏移量进行全量或部分同步,优化了数据传输效率... 目录Redis主DodMIK从复制实现原理实现原理Psync: 2.8版本后总结Redis主从复制实

锐捷和腾达哪个好? 两个品牌路由器对比分析

《锐捷和腾达哪个好?两个品牌路由器对比分析》在选择路由器时,Tenda和锐捷都是备受关注的品牌,各自有独特的产品特点和市场定位,选择哪个品牌的路由器更合适,实际上取决于你的具体需求和使用场景,我们从... 在选购路由器时,锐捷和腾达都是市场上备受关注的品牌,但它们的定位和特点却有所不同。锐捷更偏向企业级和专