LLM:Sinusoidal位置编码

2024-01-19 19:20
文章标签 位置 llm 编码 sinusoidal

本文主要是介绍LLM:Sinusoidal位置编码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1:什么是大模型的外推性?

外推性是指大模型在训练时和预测时的输入长度不一致,导致模型的泛化能力下降的问题。例如,如果一个模型在训练时只使用了512个 token 的文本,那么在预测时如果输入超过512个 token,模型可能无法正确处理。这就限制了大模型在处理长文本或多轮对话等任务时的效果。

2:为什么要位置编码PE?

Transformer结构:并行输入所以需要让输入的内容具有一定的位置信息。

句子1:我喜欢吃洋葱

句子2:洋葱喜欢吃我

3:绝对位置编码:

训练式位置编码模型只能感知到每个词向量所处的绝对位置,无法感知词向量之间的相对位置。广泛应用于早期的transformer类型的模型,如BERT、GPT、ALBERT等。但其缺点是模型不具有长度外推性。因为位置编码矩阵的大小是预设的,若对其进行扩展,将会破坏模型在预训练阶段学习到的位置信息。例如将512*768扩展为1024*768,新拓展的512个位置向量缺乏训练,无法正确表示512~1023的位置信息。

Sinusoidal位置编码:这一点得到了缓解,模型一定程度上能够感知相对位置。Sinusoidal位置编码的每个分量都是正弦或余弦函数,所有每个分量的数值都具有周期性,并且越靠后的分量,波长越长,频率越低。

Sinusoidal位置编码还具有远程衰减的性质具体表现为:对于两个相同的词向量,如果它们之间的距离越近,则他们的内积分数越高,反之则越低。如下图所示,我们随机初始化两个向量q和k,将q固定在位置0上,k的位置从0开始逐步变大,依次计算q和k之间的内积。我们发现随着q和k的相对距离的增加,它们之间的内积分数震荡衰减。

图片

因为Sinusoidal位置编码中的正弦余弦函数具备周期性,并且具备远程衰减的特性,所以理论上也具备一定长度外推的能力。

4:Sinusoidal位置编码

PE:表示位置编码

pos:表示当前字符在输入sequence中的位置

d_{model}:表示该字符嵌入的维度。 

偶数位置使用sin, 奇数位置使用cos

举例:假设每个词嵌入维度为512,如图所示:

[我, 喜,欢,你]  <-----输入sequence

[0,     1,      2,     3]      <-----对应位置    


注:如果sequence长度不够,那么不足就直接使用padding用0填充。

这里以“爱”为例,pos = 1,来说明PE的计算:

计算完所有的PE后,将词嵌入与PE进行相加,即可得到带有位置信息的embedding。 

ps:这里有一个小trick:
当emb和位置编码相加了之后,我们希望emb占多数,比如将emb放大10倍,那么在相加后的张 量里,emb就会占大部分。因为主要的语义信息是蕴含在emb当中的,我们希望位置编码带来的影响不要超过emb。所以对 emb进行了缩放再和位置编码相加。

 

python 代码1如下:更加直观

# position 就对应 token 序列中的位置索引 i
# hidden_dim 就对应词嵌入维度大小 d
# seq_len 表示 token 序列长度
def get_position_angle_vec(position):return [position / np.power(10000, 2 * (hid_j // 2) / hidden_dim) for hid_j in range(hidden_dim)]# position_angle_vecs.shape = [seq_len, hidden_dim]
position_angle_vecs = np.array([get_position_angle_vec(pos_i) for pos_i in range(seq_len)])# 分别计算奇偶索引位置对应的 sin 和 cos 值
position_angle_vecs[:, 0::2] = np.sin(position_angle_vecs[:, 0::2])  # dim 2t
position_angle_vecs[:, 1::2] = np.cos(position_angle_vecs[:, 1::2])  # dim 2t+1# positional_embeddings.shape = [1, seq_len, hidden_dim]
positional_embeddings = torch.FloatTensor(position_angle_vecs).unsqueeze(0)

python 代码2如下:

def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):# (max_len, 1)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)# (output_dim//2)ids = torch.arange(0, output_dim // 2, dtype=torch.float)  # 即公式里的i, i的范围是 [0,d/2]theta = torch.pow(10000, -2 * ids / output_dim)# (max_len, output_dim//2)embeddings = position * theta  # 即公式里的:pos / (10000^(2i/d))# (max_len, output_dim//2, 2)embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)# (bs, head, max_len, output_dim//2, 2)embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))  # 在bs维度重复,其他维度都是1不重复# (bs, head, max_len, output_dim)# reshape后就是:偶数sin, 奇数cos了embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))embeddings = embeddings.to(device)return embeddings

参考:

https://kaiyuan.blog.csdn.net/article/details/119621613

https://blog.csdn.net/qq_41915623/article/details/125166309

https://zhuanlan.zhihu.com/p/352233973

https://blog.csdn.net/u013853733/article/details/107853989

https://spaces.ac.cn/archives/8130

这篇关于LLM:Sinusoidal位置编码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/623426

相关文章

POJ1269 判断2条直线的位置关系

题目大意:给两个点能够确定一条直线,题目给出两条直线(由4个点确定),要求判断出这两条直线的关系:平行,同线,相交。如果相交还要求出交点坐标。 解题思路: 先判断两条直线p1p2, q1q2是否共线, 如果不是,再判断 直线 是否平行, 如果还不是, 则两直线相交。  判断共线:  p1p2q1 共线 且 p1p2q2 共线 ,共线用叉乘为 0  来判断,  判断 平行:  p1p

C++ | Leetcode C++题解之第393题UTF-8编码验证

题目: 题解: class Solution {public:static const int MASK1 = 1 << 7;static const int MASK2 = (1 << 7) + (1 << 6);bool isValid(int num) {return (num & MASK2) == MASK1;}int getBytes(int num) {if ((num &

C语言 | Leetcode C语言题解之第393题UTF-8编码验证

题目: 题解: static const int MASK1 = 1 << 7;static const int MASK2 = (1 << 7) + (1 << 6);bool isValid(int num) {return (num & MASK2) == MASK1;}int getBytes(int num) {if ((num & MASK1) == 0) {return

form表单提交编码的问题

浏览器在form提交后,会生成一个HTTP的头部信息"content-type",标准规定其形式为Content-type: application/x-www-form-urlencoded; charset=UTF-8        那么我们如果需要修改编码,不使用默认的,那么可以如下这样操作修改编码,来满足需求: hmtl代码:   <meta http-equiv="Conte

Linux Centos 迁移Mysql 数据位置

转自:http://www.tuicool.com/articles/zmqIn2 由于业务量增加导致安装在系统盘(20G)磁盘空间被占满了, 现在进行数据库的迁移. Mysql 是通过 yum 安装的. Centos6.5Mysql5.1 yum 安装的 mysql 服务 查看 mysql 的安装路径 执行查询 SQL show variables like

PDFQFZ高效定制:印章位置、大小随心所欲

前言 在科技编织的快节奏时代,我们不仅追求速度,更追求质量,让每一分努力都转化为生活的甜蜜果实——正是在这样的背景下,一款名为PDFQFZ-PDF的实用软件应运而生,它以其独特的功能和高效的处理能力,在PDF文档处理领域脱颖而出。 它的开发,源自于对现代办公效率提升的迫切需求。在数字化办公日益普及的今天,PDF作为一种跨平台、不易被篡改的文档格式,被广泛应用于合同签署、报告提交、证书打印等各个

[论文笔记]LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale

引言 今天带来第一篇量化论文LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale笔记。 为了简单,下文中以翻译的口吻记录,比如替换"作者"为"我们"。 大语言模型已被广泛采用,但推理时需要大量的GPU内存。我们开发了一种Int8矩阵乘法的过程,用于Transformer中的前馈和注意力投影层,这可以将推理所需

LLM系列 | 38:解读阿里开源语音多模态模型Qwen2-Audio

引言 模型概述 模型架构 训练方法 性能评估 实战演示 总结 引言 金山挂月窥禅径,沙鸟听经恋法门。 小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:卖铁观音的小男孩,今天这篇小作文主要是介绍阿里巴巴的语音多模态大模型Qwen2-Audio。近日,阿里巴巴Qwen团队发布了最新的大规模音频-语言模型Qwen2-Audio及其技术报告。该模型在音频理解和多模态交互

4-4.Andorid Camera 之简化编码模板(获取摄像头 ID、选择最优预览尺寸)

一、Camera 简化思路 在 Camera 的开发中,其实我们通常只关注打开相机、图像预览和关闭相机,其他的步骤我们不应该花费太多的精力 为此,应该提供一个工具类,它有处理相机的一些基本工具方法,包括获取摄像头 ID、选择最优预览尺寸以及打印相机参数信息 二、Camera 工具类 CameraIdResult.java public class CameraIdResult {

Python字符编码及应用

字符集概念 字符集就是一套文字符号及其编码的描述。从第一个计算机字符集ASCII开始,为了处理不同的文字,发明过几百种字符集,例如ASCII、USC、GBK、BIG5等,这些不同的字符集从收录到编码都各不相同。在编程中出现比较严重的问题是字符乱码。 几个概念 位:计算机的最小单位二进制中的一位,用二进制的0,1表示。 字节:八位组成一个字节。(位与字节有对应关系) 字符:我们肉眼可见的文字与符号。