
2024-05-09 02:44
Tags: source code, walkthrough, structure, demo, mamba



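  • Preface
  • I. Helper constructs used to build the Mamba structure
    • 1. The @dataclass decorator
    • 2. Normalization
      • LayerNorm
      • RMSNorm
      • RMSNorm source code
    • 3. nn.Parameter
  • II. How Mamba works
  • III. Building the Mamba model
    • 1. Main entry point
    • 2. The Mamba class
  • IV. ResidualBlock: the Mamba layer structure
  • V. MambaBlock, the module inside ResidualBlock
    • 1. Linear projection (producing x and res)
    • 2. 1-D convolution (processing x)
    • 3. Activation (processing x)
    • 4. SSM (processing x)
    • 5. Activation and gating (combining x and res)
    • 6. Linear projection (after x and res are combined)
  • VI. MambaBlock, the module inside ResidualBlock: the SSM
    • 1. SSM parameter initialization
    • 2. SSM structure
  • VII. Complete code demo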





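@dataclass is a Python decorator that simplifies writing data classes. A data class is a class whose main job is to hold data; the decorator adds special methods such as __init__, __repr__, and __eq__ for you, so data objects are easier to create and work with.

With @dataclass these standard methods do not have to be written by hand. Its main features are:

Generated __init__: the decorator generates an __init__ method from the annotated fields, which simplifies passing arguments when an object is instantiated.

Generated __repr__: the decorator generates a __repr__ method so that printing an object shows its class name and field values.

Generated __eq__: the decorator generates an __eq__ method that compares two objects of the same class field by field.

__hash__: not generated by default. With the default settings (eq=True, frozen=False) __hash__ is set to None, so instances are unhashable. To use an object as a dictionary key or set member, declare the class with frozen=True (or, with care, unsafe_hash=True).

__str__: not generated. print() and str() fall back to the generated __repr__, which is why printing a data class instance still gives a readable string.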

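Here is a simple example that uses @dataclass to create a data class:

from dataclasses import dataclass

# Create a data class with the @dataclass decorator
@dataclass
class Point:
    x: int
    y: int

# Create a Point object
p = Point(3, 4)
# Print the object
print(p)  # Output: Point(x=3, y=4)

In this example the @dataclass decorator creates a data class named Point with the fields x and y. Thanks to the decorator, __init__, __repr__, and similar methods do not have to be written by hand; they are generated automatically. Instantiating a Point and printing it gives a string that shows the field values.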
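The model code later in this article relies on one more dataclass hook, __post_init__, to derive values from the fields. The following is a minimal sketch of that pattern; ToyArgs is a hypothetical stand-in invented for illustration, not the article's ModelArgs.

```python
import math
from dataclasses import dataclass
from typing import Union


@dataclass
class ToyArgs:  # hypothetical class, for illustration only
    d_model: int
    expand: int = 2
    dt_rank: Union[int, str] = "auto"

    def __post_init__(self):
        # Called automatically at the end of the generated __init__.
        self.d_inner = self.expand * self.d_model
        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)


a = ToyArgs(d_model=128)
b = ToyArgs(d_model=128)
print(a)                         # ToyArgs(d_model=128, expand=2, dt_rank=8)
print(a == b)                    # True: generated __eq__ compares the fields
print(a.d_inner)                 # 256: set in __post_init__, not a field
print(ToyArgs.__hash__ is None)  # True: unhashable with the default settings
```

Because __post_init__ runs as part of construction, there is normally no need to call it again by hand after creating the object.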










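import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps  # small constant that keeps rsqrt away from division by zero
        self.weight = nn.Parameter(torch.ones(d_model))  # learnable per-channel gain

    def forward(self, x):
        # Scale x by the reciprocal root mean square over the last dimension, then apply the gain.
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        return output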
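To make the difference between the two normalizations concrete, here is a small sketch that applies the RMSNorm formula from the code above next to PyTorch's built-in layer norm. The tensor sizes and the seed are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 5, 8) * 3.0 + 1.0  # toy activations with non-zero mean
eps = 1e-5
weight = torch.ones(8)  # the learnable gain at its initial value

# RMSNorm: divide by the root mean square over the last dimension (no mean subtraction).
rms_out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

# LayerNorm: subtract the mean first, then divide by the standard deviation.
ln_out = F.layer_norm(x, (8,), eps=eps)

print(rms_out.pow(2).mean(-1).sqrt())  # root mean square is close to 1 at every position
print(ln_out.mean(-1))                 # mean is close to 0 at every position
```

RMSNorm only rescales, so it skips the mean computation that LayerNorm needs; the output keeps whatever mean the input had.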


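nn.Parameter is a PyTorch class that wraps a tensor as a model parameter so that an optimizer can train it. A tensor wrapped in nn.Parameter has requires_grad=True by default, so autograd computes its gradient during backpropagation; the optimizer then uses that gradient to update the parameter's value.


nn.Parameter is the class torch.nn.Parameter, a subclass of torch.Tensor.
When an nn.Parameter is assigned as an attribute of an nn.Module, it is registered as a model parameter and shows up in the module's parameter list (parameters() and named_parameters()), where the optimizer can reach it.
Wrapping tensors in nn.Parameter is therefore the convenient way to define model parameters that are updated during training.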

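The short example below shows how to use nn.Parameter:

import torch
import torch.nn as nn

# Create an ordinary tensor
tensor = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)
# Wrap the tensor in nn.Parameter
param = nn.Parameter(tensor)
# Print the nn.Parameter object
print(param)
# Access the gradient attribute (None until backward() has been called)
print("Gradient:", param.grad)
# Access the data attribute
print("Data:", param.data)

In this example we first create an ordinary tensor, then wrap it in an nn.Parameter object named param, and show how to print the object and read its gradient and data attributes. Note that any tensor with requires_grad=True tracks gradients during backpropagation; what nn.Parameter adds is automatic registration with the enclosing module, which is how the optimizer finds the value and updates it.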
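The registration behaviour is easiest to see inside a module. The sketch below uses a hypothetical Toy module (not part of the Mamba code) holding one nn.Parameter and one plain tensor with requires_grad=True; only the former is handed to the optimizer.

```python
import torch
import torch.nn as nn


class Toy(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))        # registered as a parameter
        self.plain = torch.ones(3, requires_grad=True)   # ordinary attribute, not registered

    def forward(self, x):
        return (x * self.weight * self.plain).sum()


m = Toy()
names = [name for name, _ in m.named_parameters()]
print(names)  # ['weight']

opt = torch.optim.SGD(m.parameters(), lr=0.1)
m(torch.tensor([1.0, 2.0, 3.0])).backward()
opt.step()

print(m.weight.data)  # updated by the optimizer step
print(m.plain.grad)   # a gradient was computed, but the value was never updated
```

Both tensors receive gradients, but only weight changes, because m.parameters() is what the optimizer iterates over.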









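if __name__ == '__main__':
    # Create a small Mamba model instance
    vocab_size = 32000
    n_layer = 2
    d_model = 128
    # __post_init__ runs automatically inside the dataclass-generated __init__,
    # so it does not need to be called explicitly.
    model_args = ModelArgs(d_model=d_model, n_layer=n_layer, vocab_size=vocab_size)
    mamba_model = Mamba(model_args)
    # Random integer tensor with values in [1, vocab_size). The shape is (batch, length):
    # batch size and sentence length, where each element is the index of one token.
    input_data = torch.randint(low=1, high=vocab_size, size=(2, 200))
    output = mamba_model(input_data)
    print(output.shape)  # torch.Size([2, 200, 32000])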


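Here we build a Mamba model; the part that actually implements the Mamba structure is the ResidualBlock module. In other words, each Mamba layer is wrapped in a residual-style block. In forward, the input token ids first pass through the embedding, which represents each token with d_model dimensions and yields a (B, L, D) tensor. The tensor then passes through the layers in turn; every layer outputs (B, L, D) data, and these layers are where Mamba processes the sequence. Finally the result goes through an RMSNorm and then lm_head, which produces the logits used to predict tokens. The code is as follows: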

class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper

    def forward(self, input_ids):
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits
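The weight-tying line in the constructor can be checked in isolation. This is a small sketch with toy sizes chosen for illustration; the ModuleDict is only there to count parameters the way a model would.

```python
import torch
import torch.nn as nn

vocab_size, d_model = 16, 4  # toy sizes, not the article's values
embedding = nn.Embedding(vocab_size, d_model)         # weight shape (vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)  # weight shape (vocab_size, d_model)
lm_head.weight = embedding.weight                     # both layers now share one Parameter

tied = nn.ModuleDict({"embedding": embedding, "lm_head": lm_head})
n_params = sum(p.numel() for p in tied.parameters())  # shared parameters are counted once

print(lm_head.weight is embedding.weight)  # True
print(n_params)                            # 64

ids = torch.tensor([[1, 2, 3]])
logits = lm_head(embedding(ids))
print(logits.shape)                        # torch.Size([1, 3, 16])
```

Tying works because nn.Linear stores its weight as (out_features, in_features), which matches the embedding table's (vocab_size, d_model) layout.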


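This is the structure of each layer. The input has shape (b, l, d) and so does the output: the layer extracts features without changing the shape of the data. Note also that RMSNorm is the normalization used here.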

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping Mamba block with normalization and residual connection."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)
        """
        output = self.mixer(self.norm(x)) + x

        return output

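The forward method is a residual-style structure: x is normalized by self.norm and then passed through self.mixer (that is, the Mamba block), which is the expression self.mixer(self.norm(x)); the original x is then added back to the result. Next I walk through the structure behind self.mixer = MambaBlock(args).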



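First, the input x of shape (b, l, d) is projected by self.in_proj to shape (b, l, 2 * d_in); this is the step marked ① in the figure. Two separate linear layers could be applied to x instead, but here a single projection computes both branches at once and x_and_res.split then divides the result. res is the right-hand branch in the figure and x is the left-hand branch.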

(b, l, d) = x.shape  # get the shape of x

x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in): one linear layer maps d to 2 * d_in
# Split the 2 * d_in channels into two halves of d_in each: x and res.
(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)
# x and res both have shape (b, l, d_in)
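The claim that one fused projection is equivalent to two separate ones can be verified directly. The following sketch uses toy sizes and a bias-free layer for brevity (an assumption for illustration), and compares the split output with two projections built from the halves of the fused weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, l, d, d_in = 2, 5, 4, 8  # toy sizes
in_proj = nn.Linear(d, 2 * d_in, bias=False)
x_in = torch.randn(b, l, d)

# Fused projection followed by a split, as in the MambaBlock code.
x, res = in_proj(x_in).split(split_size=[d_in, d_in], dim=-1)

# Equivalent: two separate projections using the two halves of the fused weight.
w_x, w_res = in_proj.weight[:d_in], in_proj.weight[d_in:]
x_sep = x_in @ w_x.T
res_sep = x_in @ w_res.T

print(x.shape, res.shape)                    # both (2, 5, 8)
print(torch.allclose(x, x_sep, atol=1e-5))   # True
print(torch.allclose(res, res_sep, atol=1e-5))  # True
```

The fused form is usually preferred because it runs one matrix multiplication instead of two.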



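x = rearrange(x, 'b l d_in -> b d_in l')  # swap l and d_in: Conv1d expects (batch, channels, length)
x = self.conv1d(x)[:, :, :l]  # 1-D convolution, truncated back to the original length l
x = rearrange(x, 'b d_in l -> b l d_in')  # swap back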
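The slice [:, :, :l] matters: with padding = d_conv - 1 on both sides the convolution returns l + d_conv - 1 steps, and keeping only the first l makes each output depend on the current and earlier inputs only. Below is a sketch of that check with toy sizes, using the same Conv1d arguments as the MambaBlock constructor in the full listing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_conv, l = 3, 4, 6  # toy sizes
conv1d = nn.Conv1d(
    in_channels=d_in,
    out_channels=d_in,
    kernel_size=d_conv,
    groups=d_in,          # depthwise: each channel has its own filter
    padding=d_conv - 1,
)

x = torch.randn(1, d_in, l)
full = conv1d(x)
print(full.shape)  # torch.Size([1, 3, 9]), i.e. l + d_conv - 1 steps
y = full[:, :, :l]

# Change only the last time step; all earlier outputs should be unaffected.
x2 = x.clone()
x2[:, :, -1] += 1.0
y2 = conv1d(x2)[:, :, :l]
print(torch.allclose(y[:, :, :-1], y2[:, :, :-1]))  # True: no leakage from the future
```

Without the truncation, the extra trailing steps would be computed from windows that extend past the end of the sequence.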



x = F.silu(x)  # apply the SiLU activation function



y = self.ssm(x)



y = y * F.silu(res)



output = self.out_proj(y)







A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(args.d_inner))
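A short sketch of what this initialization implies, written with plain torch instead of einops so it runs on its own. The sizes and the step size 0.1 are arbitrary values for illustration.

```python
import torch

d_inner, d_state = 3, 4  # toy sizes

# Same values as repeat(torch.arange(1, d_state + 1), 'n -> d n', d=d_inner):
# every row is [1, 2, ..., d_state].
A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
A_log = torch.log(A)  # the model stores log(A) as the trainable parameter

# ssm() later rebuilds A as -exp(A_log), which is negative for any value of A_log.
A_cont = -torch.exp(A_log)

# With a positive step size (softplus guarantees delta > 0), the discretized
# decay factor exp(delta * A) lies strictly between 0 and 1.
delta = torch.tensor(0.1)
deltaA = torch.exp(delta * A_cont)

print(A_cont[0])
print(deltaA[0])
```

Storing the logarithm and negating the exponential keeps A negative throughout training, so the hidden state decays rather than grows.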




    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref()
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)

        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]

        return y

    def selective_scan(self, u, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).

        Args:
            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            selective_scan_ref()
            Note: I refactored some parts out of `selective_scan_ref`, so the functionality doesn't match exactly.
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')

        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        y = y + u * D

        return y
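The recurrence inside selective_scan is easier to follow with the batch, channel, and state dimensions stripped away. Here is a sketch of the same update for a single channel with a scalar state, using only the standard library; scan_1d and its inputs are hypothetical names and values chosen for illustration.

```python
import math


def scan_1d(u, delta, A, B, C, D):
    """Sequential scan for one channel with a scalar hidden state (n = 1)."""
    h = 0.0
    ys = []
    for u_t, d_t, b_t, c_t in zip(u, delta, B, C):
        a_bar = math.exp(d_t * A)    # zero-order-hold discretization of A
        b_bar = d_t * b_t            # simplified Euler discretization of B
        h = a_bar * h + b_bar * u_t  # state update
        ys.append(c_t * h + D * u_t) # readout plus skip connection
    return ys


# Impulse input with constant delta, B, C: the output decays as exp(-t).
ys = scan_1d(
    u=[1.0, 0.0, 0.0],
    delta=[1.0, 1.0, 1.0],
    A=-1.0,
    B=[1.0, 1.0, 1.0],
    C=[1.0, 1.0, 1.0],
    D=0.0,
)
print(ys)
```

In the real model delta, B, and C change at every step because they are computed from the input, so a large delta at one position shrinks a_bar toward 0 and lets the state forget what came before; that input dependence is the "selective" part.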



"""Simple, minimal implementation of Mamba in one file of PyTorch.

Suggest reading the following before/while reading the code:
    [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
    [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)

Glossary:
    b: batch size                       (`B` in Mamba paper [1] Algorithm 2)
    l: sequence length                  (`L` in [1] Algorithm 2)
    d or d_model: hidden dim
    n or d_state: latent state dim      (`N` in [1] Algorithm 2)
    expand: expansion factor            (`E` in [1] Section 3.4)
    d_in or d_inner: d * expand         (`D` in [1] Algorithm 2)
    A, B, C, D: state space parameters  (See any state space representation formula)
                                        (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
    Δ or delta: input-dependent step size
    dt_rank: rank of Δ                  (See [1] Section 3.6 "Parameterization of ∆")
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum
from typing import Union


@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)


class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper

    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel
        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits


class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping Mamba block with normalization and residual connection."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            Block.forward()

            NOTE: the official repo chains residual blocks that look like
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons as this
            allows them to fuse the Add->Norm.

            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
        """
        output = self.mixer(self.norm(x)) + x

        return output


class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba
        """
        (b, l, d) = x.shape  # get the shape of x

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in): one linear layer maps d to 2 * d_in
        # Split the 2 * d_in channels into two halves of d_in each: x and res.
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)
        # x and res both have shape (b, l, d_in)

        x = rearrange(x, 'b l d_in -> b d_in l')  # swap l and d_in
        x = self.conv1d(x)[:, :, :l]  # 1-D convolution, truncated back to length l
        x = rearrange(x, 'b d_in l -> b l d_in')  # swap back

        x = F.silu(x)  # apply the SiLU activation function

        y = self.ssm(x)

        y = y * F.silu(res)

        output = self.out_proj(y)

        return output

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref()
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)

        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]

        return y

    def selective_scan(self, u, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).

        Args:
            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            selective_scan_ref()
            Note: I refactored some parts out of `selective_scan_ref`, so the functionality doesn't match exactly.
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')

        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        y = y + u * D

        return y


class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output


if __name__ == '__main__':
    # Create a small Mamba model instance
    vocab_size = 32000
    n_layer = 2
    d_model = 128
    # __post_init__ runs automatically inside the dataclass-generated __init__,
    # so it does not need to be called explicitly.
    model_args = ModelArgs(d_model=d_model, n_layer=n_layer, vocab_size=vocab_size)
    mamba_model = Mamba(model_args)
    # Random integer tensor with values in [1, vocab_size). The shape is (batch, length):
    # batch size and sentence length, where each element is the index of one token.
    input_data = torch.randint(low=1, high=vocab_size, size=(2, 200))
    output = mamba_model(input_data)
    print(output.shape)  # torch.Size([2, 200, 32000])



