本文主要是介绍LLM长度外推——位置插值(llama/baichuan),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
位置插值(position Interpolation, PI)通过将超出训练长度的位置索引等比例缩小,映射到模型已经学习的位置范围内,实现长度外推。
好处是不用重新训练,直接在推理时加入。
llama的实现方式
论文提出 Extending Context Window of Large Language Models via Positional Interpolation
llama采用Rope位置编码,因此其实现都是针对rope编码的位置插值。
官方实现的代码:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L148
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""def forward(self, x, position_ids):# difference to the original RoPE: a scaling factor is aplied to the position idsposition_ids = position_ids.float() / self.scaling_factorcos, sin = super().forward(x, position_ids)return cos, sin
Super-HOT项目的实现
位置插值原理介绍: https://kaiokendev.github.io/til#extending-context-to-8k
源代码:https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py
class ScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):super().__init__()inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)max_position_embeddings = 8192# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached,device=self.inv_freq.device,dtype=self.inv_freq.dtype,)# These two lines:self.scale = 1 / 4t *= self.scale
参考:
1.https://zhuanlan.zhihu.com/p/679147878
2.https://blog.csdn.net/v_JULY_v/article/details/135072211
3.https://kaiokendev.github.io/til#extending-context-to-8k
百川的实现方式
百川13B的位置编码是Alibi。因此是针对Alibi的长度外推。
有测试表明外推最大长度大约是训练的8倍时可以达到最佳性能:评论区
实现代码和步骤:
https://github.com/seanzhang-zhichen/baichuan-Dynamic-NTK-ALiBi
参考:
1.https://zhuanlan.zhihu.com/p/657161287
2.https://zhuanlan.zhihu.com/p/647628295
这篇关于LLM长度外推——位置插值(llama/baichuan)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!