本文主要是介绍【HuggingFace Transformers】LlamaMLP源码解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
LlamaMLP源码解析
- 1. LlamaMLP 介绍
- 2. LlamaMLP类 源码解析
1. LlamaMLP 介绍
LlamaMLP 是 LLaMA 模型中的 MLP 层,主要用于对输入特征进行非线性变换。在分片预训练模式下,线性层的权重被切分,分步处理后再进行拼接和求和,而在常规模式下,直接应用线性变换和激活函数处理输入数据。其计算公式为:
o u t p u t = W d o w n ⋅ ( σ ( W g a t e ⋅ x + b g a t e ) ⊙ ( W u p ⋅ x + b u p ) ) + b d o w n output = W_{down}\cdot(\sigma(W_{gate}\cdot x+b_{gate})\odot (W_{up}\cdot x+b_{up})) +b_{down} output=Wdown⋅(σ(Wgate⋅x+bgate)⊙(Wup⋅x+bup))+bdown
2. LlamaMLP类 源码解析
源码地址:transformers/src/transformers/models/llama/modeling_llama.py
# -*- coding: utf-8 -*-
# @time: 2024/8/28 15:16import torch
import torch.nn.functional as Ffrom torch import nn
from transformers.activations import ACT2FNclass LlamaMLP(nn.Module):def __init__(self, config):super().__init__()self.config = config # 配置参数self.hidden_size = config.hidden_size # 隐藏层的维度self.intermediate_size = config.intermediate_size # 中间层的维度self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) # 定义第一个线性变换层,将隐藏层映射到中间层self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) # 定义第二个线性变换层,将隐藏层映射到中间层self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) # 定义第三个线性变换层,将中间层的输出映射回隐藏层self.act_fn = ACT2FN[config.hidden_act] # 根据配置选择激活函数def forward(self, x):# 如果是分片预训练if self.config.pretraining_tp > 1:slice = self.intermediate_size // self.config.pretraining_tp # 计算每个切片的大小gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) # 将 gate_proj 层的权重按行切分成多个切片up_proj_slices = self.up_proj.weight.split(slice, dim=0) # 将 up_proj 层的权重按行切分成多个切片down_proj_slices = self.down_proj.weight.split(slice, dim=1) # 将 down_proj 层的权重按列切分成多个切片gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) # 对输入 x 应用每个 gate_proj 切片的线性变换,并沿列拼接up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) # 对输入 x 应用每个 up_proj 切片的线性变换,并沿列拼接intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) # 应用激活函数后,与 up_proj 结果逐元素相乘,并沿通道切分成多个张量down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)] # 对每个 intermediate_states 切片应用对应的 down_proj 切片的线性变换down_proj = sum(down_proj) # 将所有 down_proj 切片的结果相加else:# 如果不是分片预训练,直接进行线性变换和激活函数处理down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))# 返回最终的输出结果return down_proj
这篇关于【HuggingFace Transformers】LlamaMLP源码解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!