LLaMA代码笔记 --基于lit-llama

2024-08-26 18:44
文章标签 代码 笔记 llama lit

本文主要是介绍LLaMA代码笔记 --基于lit-llama,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

代码来自:lit-llama
modelscope模型下载 :llama-7b
下载后的模型需要转换为lit-llama使用的格式,详见 howto 文件夹下的 download_weights.md

文中代码为了方便说明,删减了一些内容,详细代码请查看源码。

generate

输入参数:

  • idx: 输入的prompt经过 tokenizer encode之后输出的序列tensor.使用了默认的输入,token长度为6。
  • max_new_tokens: 每次新生成的最大token数
  • max_seq_length: 输入的序列最大长度.
  • temperature: 温度越高,结果越多样性;温度越低,确定性越高。
  • top_k: 默认为200。topk越大,结果越多样性;topk越小,结果确定性越高。
@torch.no_grad()
def generate(model: LLaMA,idx: torch.Tensor,max_new_tokens: int,*,max_seq_length: Optional[int] = None,temperature: float = 1.0,top_k: Optional[int] = None,eos_id: Optional[int] = None,
) -> torch.Tensor:# create an empty tensor of the expected final shape and fill in the current tokensT = idx.size(0)T_new = T + max_new_tokensif max_seq_length is None:max_seq_length = min(T_new, model.config.block_size)device, dtype = idx.device, idx.dtype# 创建了一个空的tensor,包括输入的idx,加上允许生成的最大tokens 数,定义了最终结果变量empty = torch.empty(T_new, dtype=dtype, device=device)empty[:T] = idxidx = emptyinput_pos = torch.arange(0, T, device=device) #指明输入的数据在idx中的pos。# generate max_new_tokens tokensfor _ in range(max_new_tokens):x = idx.index_select(0, input_pos).view(1, -1) #在结果变量idx中使用input_pos 取出当前输入。# forwardlogits = model(x, max_seq_length, input_pos)   #(1,seq,32000)logits = logits[0, -1] / temperature# optionally crop the logits to only the top k optionsif top_k is not None:v, _ = torch.topk(logits, min(top_k, logits.size(-1)))logits = torch.where(logits < v[[-1]], -float("Inf"), logits)probs = torch.nn.functional.softmax(logits, dim=-1)idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype) #多项式采样# advanceinput_pos = input_pos[-1:] + 1#下一个输入的pos# concatenate the new generationidx = idx.index_copy(0, input_pos, idx_next) #把生成结果copy到结果变量中# if <eos> token is triggered, return the output (stop generation)if idx_next == eos_id:return idx[:input_pos]  # include the EOS tokenreturn idx

LLAMA Model

使用默认配置,7B模型。block_size 定义了rope和mask的大小,自然也限制了最大输入长度,超过了block_size的输入,无法取得位置编码和mask。

@dataclass
class LLaMAConfig:block_size: int = 2048vocab_size: int = 32000padded_vocab_size: Optional[int] = Nonen_layer: int = 32n_head: int = 32n_embd: int = 4096def __post_init__(self):if self.padded_vocab_size is None:self.padded_vocab_size = find_multiple(self.vocab_size, 64)@classmethoddef from_name(cls, name: str) -> Self:return cls(**llama_configs[name])llama_configs = {"7B": dict(n_layer=32, n_head=32, n_embd=4096),"13B": dict(n_layer=40, n_head=40, n_embd=5120),"30B": dict(n_layer=60, n_head=52, n_embd=6656),"65B": dict(n_layer=80, n_head=64, n_embd=8192),
}

LLaMA模型主要有多层attention模块构成。
预测第一个token的时候,需要创建 build_rope_cache 和 build_mask_cache,以及 kv_caches。
然后从rope_cache 和mask_cache 根据 input_pos 取出对应位置的值。
kv_caches 即 缓存模型中所有层的kv值,7B有32层,则 kv_caches 的长度为32.
kv的shape为(B, self.config.n_head, max_seq_length, head_size),使用torch.zeros 初始化。
逐层运行,并将每层的kv值保存在kv_caches 中。
这里每次输入的长度肯定是小于max_seq_length,也就是只更新相应index的kv_caches中的值。
最后经过RMSNorm后,经过线性层,输出每个vocab的概率.

class LLaMA(nn.Module):def __init__(self, config: LLaMAConfig) -> None:super().__init__()assert config.padded_vocab_size is not Noneself.config = configself.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)self.transformer = nn.ModuleDict(dict(wte=nn.Embedding(config.padded_vocab_size, config.n_embd),h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),ln_f=RMSNorm(config.n_embd),))self.rope_cache: Optional[RoPECache] = Noneself.mask_cache: Optional[MaskCache] = Noneself.kv_caches: List[KVCache] = []def forward(self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:B, T = idx.size()block_size = self.config.block_sizeif max_seq_length is None:max_seq_length = block_sizeassert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"if self.rope_cache is None:self.rope_cache = self.build_rope_cache(idx)if self.mask_cache is None:self.mask_cache = self.build_mask_cache(idx)#从rope_cache 和 mask_cache 取出对应位置的rope和maskif input_pos is not None:rope = self.rope_cache.index_select(0, input_pos) #(6,64,2),(1,64,2)mask = self.mask_cache.index_select(2, input_pos) #(1,1,6,2048),(1,1,1,2048)mask = mask[:, :, :, :max_seq_length] #1,1,6,56),(1,1,1,56)else:#未给出input_pos,则根据输入长度rope = self.rope_cache[:T]mask = self.mask_cache[:, :, :T, :T]# embeddingsx = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd) #(1,1,4096)if input_pos is None:  # proxy for use_cache=Falsefor block in self.transformer.h:x, _ = block(x, rope, mask, max_seq_length)else:if not self.kv_caches: #创建kv_cacheshead_size = self.config.n_embd // self.config.n_head  #128cache_shape = (B, self.config.n_head, max_seq_length, head_size) #(1,32,56,128)self.kv_caches = [(torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype))for _ in range(self.config.n_layer)]for i, block in enumerate(self.transformer.h):x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])#RMSNormx = self.transformer.ln_f(x) #(1,6,4096)logits = self.lm_head(x)  # (b, t, vocab_size) (1,6,32000)return logitsdef build_rope_cache(self, idx: torch.Tensor) -> RoPECache:return build_rope_cache(seq_len=self.config.block_size,n_elem=self.config.n_embd // self.config.n_head,dtype=idx.dtype,device=idx.device,)# mask_cache 的shape为 (block_size,block_size),右上角为False。def build_mask_cache(self, idx: torch.Tensor) -> MaskCache:ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)return torch.tril(ones).unsqueeze(0).unsqueeze(0)

build_rope_cache

rope 按照下面的计算方法计算,有很多的shape转换,可以使用较小的维度对照公式逐步查看。
参考:一文通透位置编码:从标准位置编码、旋转位置编码RoPE到ALiBi、LLaMA 2 Long(含NTK-aware简介)
rope_cache 的shape为 (2048,64,2),2048是模型定义的block_size,64 为 attention中每个head 的dim 再除以2。

旋转角度计算公式
在这里插入图片描述

def build_rope_cache( #2048,128=4096/32seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
) -> RoPECache:# 上面的角度计算公式,2(i-1),i从1到d/2,就等于torch.arange(0, d, 2)theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))  #(64)# Create position indexes `[0, 1, ..., seq_len - 1]`seq_idx = torch.arange(seq_len, dtype=dtype, device=device)  #(2048)# Calculate the product of position index and $\theta_i$#每个角度都乘以indexidx_theta = torch.outer(seq_idx, theta).float()  #(2048,64)cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) #(2048,64,2)# this is to mimic the behaviour of complex32, else we will get different resultsif dtype in (torch.float16, torch.bfloat16, torch.int8):cache = cache.half()return cache

在这里插入图片描述
在attention中,只有q和k需要添加位置信息,按照上图来计算。

def apply_rope(x: torch.Tensor, rope_cache: RoPECache) -> torch.Tensor:# truncate to support variable sizesT = x.size(1) #(1,6,32,128)rope_cache = rope_cache[:T] #(6,64,2)# cast because the reference doesxshaped = x.float().reshape(*x.shape[:-1], -1, 2) #(1,6,32,64,2)rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)  #(1,6,1,64,2)x_out2 = torch.stack([xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],],-1,)  #(1,6,32,64,2)x_out2 = x_out2.flatten(3) #(1,6,32,128)return x_out2.type_as(x)

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/0eff924a1239455e906050c003365315.png
图片是llama2的结构图,llama1并没有使用GQA,其他结构是一样的。

先RMSNorm,再attention,残差相加,然后再RMSNorm,MLP,再次残差相加。

class Block(nn.Module):def __init__(self, config: LLaMAConfig) -> None:super().__init__()self.rms_1 = RMSNorm(config.n_embd)self.attn = CausalSelfAttention(config)self.rms_2 = RMSNorm(config.n_embd)self.mlp = MLP(config)def forward(self,x: torch.Tensor,rope: RoPECache,mask: MaskCache,max_seq_length: int,input_pos: Optional[torch.Tensor] = None,kv_cache: Optional[KVCache] = None,) -> Tuple[torch.Tensor, Optional[KVCache]]:h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)x = x + hx = x + self.mlp(self.rms_2(x))return x, new_kv_cache

在attention部分中,q和k添加了rope。
使用了kv_cache,kv_cache的初始值都是 0,这里就需要把计算出的k 和 v copy到 kv_cache中对应的index位置。
在generate的for循环中,第一次输入全部的prompt,假设长度为 6;第二次只输入生成的token,长度为1,也就是说第二次以后,每次的x size都是(1,1,4096),因为之前的kv值都已经存在kv_cache中了。
如果input_pos >= max_seq_length,cache_k 和cache_v 就要左移,丢弃最早的kv值。

class CausalSelfAttention(nn.Module):def __init__(self, config: LLaMAConfig) -> None:super().__init__()assert config.n_embd % config.n_head == 0# key, query, value projections for all heads, but in a batchself.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)# output projectionself.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)self.n_head = config.n_headself.n_embd = config.n_embdself.block_size = config.block_sizedef forward(self,x: torch.Tensor,rope: RoPECache,mask: MaskCache,max_seq_length: int,input_pos: Optional[torch.Tensor] = None,kv_cache: Optional[KVCache] = None,) -> Tuple[torch.Tensor, Optional[KVCache]]:B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd) #(1,6,4096)# calculate query, key, values for all heads in batch and move head forward to be the batch dimq, k, v = self.c_attn(x).split(self.n_embd, dim=2) #(1,6,4096)head_size = C // self.n_head #128k = k.view(B, T, self.n_head, head_size) #(1,6,32,128)q = q.view(B, T, self.n_head, head_size)v = v.view(B, T, self.n_head, head_size)q = apply_rope(q, rope)k = apply_rope(k, rope)k = k.transpose(1, 2)  # (B, nh, T, hs) (1,32,6,128)q = q.transpose(1, 2)  # (B, nh, T, hs)v = v.transpose(1, 2)  # (B, nh, T, hs)if kv_cache is not None:cache_k, cache_v = kv_cache #(1,32,56,128),(1,32,56,128)# check if reached token limitif input_pos[-1] >= max_seq_length:input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)# 左移,丢弃最早的kv值cache_k = torch.roll(cache_k, -1, dims=2)cache_v = torch.roll(cache_v, -1, dims=2)k = cache_k.index_copy(2, input_pos, k) #(1,32,56,128)v = cache_v.index_copy(2, input_pos, v) #(1,32,56,128)kv_cache = k, v# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)#  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))#  att = att.masked_fill(mask[:,:,:T,:T] == 0, float('-inf'))#  att = F.softmax(att, dim=-1)#  y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)#这里使用了kv_cache后,kv的shape和q不再相同,#(1,32,6,128),(1,32,56,128),(1,32,56,128) => (1,32,6,128)y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side #(1,6,4096)# output projectiony = self.c_proj(y)return y, kv_cache

MLP使用了SwiGLU函数,n_hidden也是个奇怪的数。
Llama改进之——SwiGLU激活函数

class MLP(nn.Module):def __init__(self, config: LLaMAConfig) -> None:super().__init__()hidden_dim = 4 * config.n_embdn_hidden = int(2 * hidden_dim / 3)n_hidden = find_multiple(n_hidden, 256)self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)def forward(self, x: torch.Tensor) -> torch.Tensor:x = F.silu(self.c_fc1(x)) * self.c_fc2(x)x = self.c_proj(x)return x

参考:
LLaMA的解读与其微调

这篇关于LLaMA代码笔记 --基于lit-llama的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1109372

相关文章

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

MySQL数据库的内嵌函数和联合查询实例代码

《MySQL数据库的内嵌函数和联合查询实例代码》联合查询是一种将多个查询结果组合在一起的方法,通常使用UNION、UNIONALL、INTERSECT和EXCEPT关键字,下面:本文主要介绍MyS... 目录一.数据库的内嵌函数1.1聚合函数COUNT([DISTINCT] expr)SUM([DISTIN

Java实现自定义table宽高的示例代码

《Java实现自定义table宽高的示例代码》在桌面应用、管理系统乃至报表工具中,表格(JTable)作为最常用的数据展示组件,不仅承载对数据的增删改查,还需要配合布局与视觉需求,而JavaSwing... 目录一、项目背景详细介绍二、项目需求详细介绍三、相关技术详细介绍四、实现思路详细介绍五、完整实现代码

Go语言代码格式化的技巧分享

《Go语言代码格式化的技巧分享》在Go语言的开发过程中,代码格式化是一个看似细微却至关重要的环节,良好的代码格式化不仅能提升代码的可读性,还能促进团队协作,减少因代码风格差异引发的问题,Go在代码格式... 目录一、Go 语言代码格式化的重要性二、Go 语言代码格式化工具:gofmt 与 go fmt(一)

HTML5实现的移动端购物车自动结算功能示例代码

《HTML5实现的移动端购物车自动结算功能示例代码》本文介绍HTML5实现移动端购物车自动结算,通过WebStorage、事件监听、DOM操作等技术,确保实时更新与数据同步,优化性能及无障碍性,提升用... 目录1. 移动端购物车自动结算概述2. 数据存储与状态保存机制2.1 浏览器端的数据存储方式2.1.

基于 HTML5 Canvas 实现图片旋转与下载功能(完整代码展示)

《基于HTML5Canvas实现图片旋转与下载功能(完整代码展示)》本文将深入剖析一段基于HTML5Canvas的代码,该代码实现了图片的旋转(90度和180度)以及旋转后图片的下载... 目录一、引言二、html 结构分析三、css 样式分析四、JavaScript 功能实现一、引言在 Web 开发中,

Python如何去除图片干扰代码示例

《Python如何去除图片干扰代码示例》图片降噪是一个广泛应用于图像处理的技术,可以提高图像质量和相关应用的效果,:本文主要介绍Python如何去除图片干扰的相关资料,文中通过代码介绍的非常详细,... 目录一、噪声去除1. 高斯噪声(像素值正态分布扰动)2. 椒盐噪声(随机黑白像素点)3. 复杂噪声(如伪

Java Spring ApplicationEvent 代码示例解析

《JavaSpringApplicationEvent代码示例解析》本文解析了Spring事件机制,涵盖核心概念(发布-订阅/观察者模式)、代码实现(事件定义、发布、监听)及高级应用(异步处理、... 目录一、Spring 事件机制核心概念1. 事件驱动架构模型2. 核心组件二、代码示例解析1. 事件定义

Python实例题之pygame开发打飞机游戏实例代码

《Python实例题之pygame开发打飞机游戏实例代码》对于python的学习者,能够写出一个飞机大战的程序代码,是不是感觉到非常的开心,:本文主要介绍Python实例题之pygame开发打飞机... 目录题目pygame-aircraft-game使用 Pygame 开发的打飞机游戏脚本代码解释初始化部