本文主要是介绍Mamba-minimal Mamba的最小限度实现 (一),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文章目录
- 参数和数据尺寸约定
- class MambaBlock
- def forward
- def __ int__
- def ssm
- def selective_scan
johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)
manba的简单最小限度实现,和原始论文实现state-spaces/mamba (github.com)](https://github.com/state-spaces/mamba/tree/main)相比,为了可读性对参数没有很好的初始化,原论文用CUDA写了并行扫描,所以速度会快。
这里介绍Mamba Block的实现
参数和数据尺寸约定
之后的数据尺寸以(b, l, d_in) 或者(b, l, d_model, d_state)简单表示
参数及简写 | Mamba论文简写 |
---|---|
batch_size b | B |
序列长度 l | L |
隐藏维度 d / d_model | |
潜在状态维度 n / d_state | N |
扩展因子 expand | E |
d_in / d_inner | D |
数据依赖步长 Δ \Delta Δ / delta | |
delta秩 dt_rank |
class MambaBlock
def forward
根据forward简单梳理MambaBlock的结构
中间变量 | 来源 | shape |
---|---|---|
输入x | (b, l, d_model) | |
x_and_res | x经过输入映射后 | (b, l, 2* d_in) |
x | 切分后作为ssm分支输入 | (b, l, d_in) |
res | 切分后作为门控分支输入 | (b, l, d_in) |
y | 经过卷积,激活,ssm,门控后的输出 | (b, l, d_in) |
output | y经过输出映射后得到 | (b, l, d_model) |
def forward(self, x):(b, l, d) = x.shapex_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)x = rearrange(x, 'b l d_in -> b d_in l')x = self.conv1d(x)[:, :, :l]x = rearrange(x, 'b d_in l -> b l d_in')x = F.silu(x)y = self.ssm(x)y = y * F.silu(res)output = self.out_proj(y)return output
def __ int__
初始化主要初始了几个部分
组件定义
操作及简写 | 维度变换 |
---|---|
输入映射 in_proj | (b, l, d_model) -> (b, l, 2*d_in) |
序列变换 conv1d | 只取前l (b, d_in, l) -> (b, d_in, l) |
非线性激活 silu | |
输出映射 out_proj | (b, l, d_in) -> (b, l, d) |
ssm初始化
操作及简写 | 作用 |
---|---|
参数生成映射 x_proj | 生成数据依赖的参数B, C, Δ \Delta Δ |
delta映射 dt_proj | 将 Δ \Delta Δ从dt_rank映射到d_in |
矩阵A初始化 | 简单初始化 |
矩阵D初始化 | 简单初始化 |
def __init__(self, args: ModelArgs):super().__init__()self.args = argsself.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)self.conv1d = nn.Conv1d(in_channels=args.d_inner,out_channels=args.d_inner,bias=args.conv_bias,kernel_size=args.d_conv,groups=args.d_inner,padding=args.d_conv - 1,)# ssm模型的初始化部分# x_proj takes in `x` and outputs the input-specific Δ, B, Cself.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)# dt_proj projects Δ from dt_rank to d_inself.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)self.A_log = nn.Parameter(torch.log(A))self.D = nn.Parameter(torch.ones(args.d_inner))self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
def ssm
这是我们数据处理流水线的搭建,这一部分是ssm模型参数定义,是ssm模型中相对于数据“不变”的部分。
SSM参数 | shape | 来源 |
---|---|---|
状态矩阵A | (d_in, n) | 在初始化中定义,非数据依赖 |
输入矩阵B | (b, l, n) | 由x_db1切分而来,因此数据依赖 |
输出矩阵C | (b, l, n) | 由x_db1切分而来,因此数据依赖 |
直接传递矩阵D | (d_in) | 在初始化中定义,非数据依赖 |
数据依赖步长 Δ \Delta Δ | (b, l, d_in) | 由x_db1切分而来,因此数据依赖 |
其中一部分变量初始化于class MambaBlock的初始化部分
中间变量及简写 | 来源 |
---|---|
数据生成变量 x_db1 | x经过参数映射x_proj生成 |
最终delta Δ \Delta Δ | 切分而来的 Δ \Delta Δ经过映射和softplus |
def ssm(self, x):(d_in, n) = self.A_log.shapeA = -torch.exp(self.A_log.float()) # shape (d_in, n)D = self.D.float()x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n)delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in)y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]return y
SSM参数 | shape |
---|---|
状态矩阵A | (d_in, n) |
输入矩阵B | (b, l, n) |
输出矩阵C | (b, l, n) |
直接传递矩阵D | (d_in) |
def selective_scan
我们的数据流水线搭建好以后,接下来就要让它动起来,这一部分是数据处理的动态或者动力。
在这里, A A A使用ZOH零阶保持离散化, B B B则简化为欧拉离散化
前向欧拉离散化
x k = ( I + Δ k A ) x k − 1 + Δ k B ⋅ u k x ( t + Δ ) = ( I + Δ A ) x ( t ) + Δ B ⋅ u ( t ) \begin{aligned} x_{k}& \begin{aligned}=(\boldsymbol{I}+\Delta_{k}\boldsymbol{A})x_{k-1}+\Delta_{k}\boldsymbol{B}\cdot u_{k}\end{aligned} \\ x(t+\Delta)& =(\boldsymbol{I}+\Delta\boldsymbol{A})x(t)+\Delta\boldsymbol{B}\cdot u(t) \end{aligned} xkx(t+Δ)=(I+ΔkA)xk−1+ΔkB⋅uk=(I+ΔA)x(t)+ΔB⋅u(t)
零阶保持离散化
x k = e Δ k A x k − 1 + ( Δ k A ) − 1 ( e Δ k A − I ) ⋅ Δ k B ⋅ u k x ( t + Δ ) = e Δ A x ( t ) + ( Δ A ) − 1 ( e Δ A − I ) ⋅ Δ B ⋅ u ( t ) \begin{aligned} x_{k}& =e^{\Delta_{k}\boldsymbol A}x_{k-1}+(\Delta_{k}\boldsymbol A)^{-1}(e^{\Delta_{k}\boldsymbol A}-\boldsymbol{I})\cdot\Delta_{k}\boldsymbol B\cdot u_{k} \\ x(t+\Delta)& =e^{\Delta \boldsymbol A}x(t)+(\Delta \boldsymbol A)^{-1}(e^{\Delta \boldsymbol A}-\boldsymbol{I})\cdot\Delta \boldsymbol B\cdot u(t) \end{aligned} xkx(t+Δ)=eΔkAxk−1+(ΔkA)−1(eΔkA−I)⋅ΔkB⋅uk=eΔAx(t)+(ΔA)−1(eΔA−I)⋅ΔB⋅u(t)
这里selective_scan是顺序形式,因此与原论文CUDA编写的并行感知算法相比要慢
def selective_scan(self, u, delta, A, B, C, D):(b, l, d_in) = u.shapen = A.shape[1]deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')# Perform selective scan (see scan_SSM() in The Annotated S4 [2])x = torch.zeros((b, d_in, n), device=deltaA.device)ys = [] for i in range(l):x = deltaA[:, i] * x + deltaB_u[:, i]y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')ys.append(y)y = torch.stack(ys, dim=1) # shape (b, l, d_in)y = y + u * Dreturn y
这篇关于Mamba-minimal Mamba的最小限度实现 (一)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!