Mamba-minimal Mamba的最小限度实现 (一)

2024-03-09 06:20

本文主要是介绍Mamba-minimal Mamba的最小限度实现 (一),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 参数和数据尺寸约定
    • class MambaBlock
      • def forward
      • def __ int__
      • def ssm
      • def selective_scan

johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)

manba的简单最小限度实现,和原始论文实现state-spaces/mamba (github.com)](https://github.com/state-spaces/mamba/tree/main)相比,为了可读性对参数没有很好的初始化,原论文用CUDA写了并行扫描,所以速度会快。
这里介绍Mamba Block的实现

参数和数据尺寸约定

之后的数据尺寸以(b, l, d_in) 或者(b, l, d_model, d_state)简单表示

参数及简写Mamba论文简写
batch_size bB
序列长度 lL
隐藏维度 d / d_model
潜在状态维度 n / d_stateN
扩展因子 expandE
d_in / d_innerD
数据依赖步长 Δ \Delta Δ / delta
delta秩 dt_rank

class MambaBlock

def forward

根据forward简单梳理MambaBlock的结构

在这里插入图片描述

中间变量来源shape
输入x(b, l, d_model)
x_and_resx经过输入映射后(b, l, 2* d_in)
x切分后作为ssm分支输入(b, l, d_in)
res切分后作为门控分支输入(b, l, d_in)
y经过卷积,激活,ssm,门控后的输出(b, l, d_in)
outputy经过输出映射后得到(b, l, d_model)
def forward(self, x):(b, l, d) = x.shapex_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)x = rearrange(x, 'b l d_in -> b d_in l')x = self.conv1d(x)[:, :, :l]x = rearrange(x, 'b d_in l -> b l d_in')x = F.silu(x)y = self.ssm(x)y = y * F.silu(res)output = self.out_proj(y)return output

def __ int__

初始化主要初始了几个部分

组件定义

操作及简写维度变换
输入映射 in_proj(b, l, d_model) -> (b, l, 2*d_in)
序列变换 conv1d只取前l (b, d_in, l) -> (b, d_in, l)
非线性激活 silu
输出映射 out_proj(b, l, d_in) -> (b, l, d)

在这里插入图片描述
ssm初始化

操作及简写作用
参数生成映射 x_proj生成数据依赖的参数B, C, Δ \Delta Δ
delta映射 dt_proj Δ \Delta Δ从dt_rank映射到d_in
矩阵A初始化简单初始化
矩阵D初始化简单初始化
def __init__(self, args: ModelArgs):super().__init__()self.args = argsself.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)self.conv1d = nn.Conv1d(in_channels=args.d_inner,out_channels=args.d_inner,bias=args.conv_bias,kernel_size=args.d_conv,groups=args.d_inner,padding=args.d_conv - 1,)# ssm模型的初始化部分# x_proj takes in `x` and outputs the input-specific Δ, B, Cself.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)# dt_proj projects Δ from dt_rank to d_inself.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)self.A_log = nn.Parameter(torch.log(A))self.D = nn.Parameter(torch.ones(args.d_inner))self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

def ssm

这是我们数据处理流水线的搭建,这一部分是ssm模型参数定义,是ssm模型中相对于数据“不变”的部分。

SSM参数shape来源
状态矩阵A(d_in, n)在初始化中定义,非数据依赖
输入矩阵B(b, l, n)由x_db1切分而来,因此数据依赖
输出矩阵C(b, l, n)由x_db1切分而来,因此数据依赖
直接传递矩阵D(d_in)在初始化中定义,非数据依赖
数据依赖步长 Δ \Delta Δ(b, l, d_in)由x_db1切分而来,因此数据依赖

其中一部分变量初始化于class MambaBlock的初始化部分

中间变量及简写来源
数据生成变量 x_db1x经过参数映射x_proj生成
最终delta Δ \Delta Δ切分而来的 Δ \Delta Δ经过映射和softplus
 def ssm(self, x):(d_in, n) = self.A_log.shapeA = -torch.exp(self.A_log.float())  # shape (d_in, n)D = self.D.float()x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]return y
SSM参数shape
状态矩阵A(d_in, n)
输入矩阵B(b, l, n)
输出矩阵C(b, l, n)
直接传递矩阵D(d_in)

def selective_scan

我们的数据流水线搭建好以后,接下来就要让它动起来,这一部分是数据处理的动态或者动力。

在这里插入图片描述

在这里, A A A使用ZOH零阶保持离散化, B B B则简化为欧拉离散化

前向欧拉离散化
x k = ( I + Δ k A ) x k − 1 + Δ k B ⋅ u k x ( t + Δ ) = ( I + Δ A ) x ( t ) + Δ B ⋅ u ( t ) \begin{aligned} x_{k}& \begin{aligned}=(\boldsymbol{I}+\Delta_{k}\boldsymbol{A})x_{k-1}+\Delta_{k}\boldsymbol{B}\cdot u_{k}\end{aligned} \\ x(t+\Delta)& =(\boldsymbol{I}+\Delta\boldsymbol{A})x(t)+\Delta\boldsymbol{B}\cdot u(t) \end{aligned} xkx(t+Δ)=(I+ΔkA)xk1+ΔkBuk=(I+ΔA)x(t)+ΔBu(t)

零阶保持离散化
x k = e Δ k A x k − 1 + ( Δ k A ) − 1 ( e Δ k A − I ) ⋅ Δ k B ⋅ u k x ( t + Δ ) = e Δ A x ( t ) + ( Δ A ) − 1 ( e Δ A − I ) ⋅ Δ B ⋅ u ( t ) \begin{aligned} x_{k}& =e^{\Delta_{k}\boldsymbol A}x_{k-1}+(\Delta_{k}\boldsymbol A)^{-1}(e^{\Delta_{k}\boldsymbol A}-\boldsymbol{I})\cdot\Delta_{k}\boldsymbol B\cdot u_{k} \\ x(t+\Delta)& =e^{\Delta \boldsymbol A}x(t)+(\Delta \boldsymbol A)^{-1}(e^{\Delta \boldsymbol A}-\boldsymbol{I})\cdot\Delta \boldsymbol B\cdot u(t) \end{aligned} xkx(t+Δ)=eΔkAxk1+(ΔkA)1(eΔkAI)ΔkBuk=eΔAx(t)+(ΔA)1(eΔAI)ΔBu(t)

这里selective_scan是顺序形式,因此与原论文CUDA编写的并行感知算法相比要慢

def selective_scan(self, u, delta, A, B, C, D):(b, l, d_in) = u.shapen = A.shape[1]deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')# Perform selective scan (see scan_SSM() in The Annotated S4 [2])x = torch.zeros((b, d_in, n), device=deltaA.device)ys = []    for i in range(l):x = deltaA[:, i] * x + deltaB_u[:, i]y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')ys.append(y)y = torch.stack(ys, dim=1)  # shape (b, l, d_in)y = y + u * Dreturn y

这篇关于Mamba-minimal Mamba的最小限度实现 (一)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/789776

相关文章

golang版本升级如何实现

《golang版本升级如何实现》:本文主要介绍golang版本升级如何实现问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录golanwww.chinasem.cng版本升级linux上golang版本升级删除golang旧版本安装golang最新版本总结gola

SpringBoot中SM2公钥加密、私钥解密的实现示例详解

《SpringBoot中SM2公钥加密、私钥解密的实现示例详解》本文介绍了如何在SpringBoot项目中实现SM2公钥加密和私钥解密的功能,通过使用Hutool库和BouncyCastle依赖,简化... 目录一、前言1、加密信息(示例)2、加密结果(示例)二、实现代码1、yml文件配置2、创建SM2工具

Mysql实现范围分区表(新增、删除、重组、查看)

《Mysql实现范围分区表(新增、删除、重组、查看)》MySQL分区表的四种类型(范围、哈希、列表、键值),主要介绍了范围分区的创建、查询、添加、删除及重组织操作,具有一定的参考价值,感兴趣的可以了解... 目录一、mysql分区表分类二、范围分区(Range Partitioning1、新建分区表:2、分

MySQL 定时新增分区的实现示例

《MySQL定时新增分区的实现示例》本文主要介绍了通过存储过程和定时任务实现MySQL分区的自动创建,解决大数据量下手动维护的繁琐问题,具有一定的参考价值,感兴趣的可以了解一下... mysql创建好分区之后,有时候会需要自动创建分区。比如,一些表数据量非常大,有些数据是热点数据,按照日期分区MululbU

MySQL中查找重复值的实现

《MySQL中查找重复值的实现》查找重复值是一项常见需求,比如在数据清理、数据分析、数据质量检查等场景下,我们常常需要找出表中某列或多列的重复值,具有一定的参考价值,感兴趣的可以了解一下... 目录技术背景实现步骤方法一:使用GROUP BY和HAVING子句方法二:仅返回重复值方法三:返回完整记录方法四:

IDEA中新建/切换Git分支的实现步骤

《IDEA中新建/切换Git分支的实现步骤》本文主要介绍了IDEA中新建/切换Git分支的实现步骤,通过菜单创建新分支并选择是否切换,创建后在Git详情或右键Checkout中切换分支,感兴趣的可以了... 前提:项目已被Git托管1、点击上方栏Git->NewBrancjsh...2、输入新的分支的

Python实现对阿里云OSS对象存储的操作详解

《Python实现对阿里云OSS对象存储的操作详解》这篇文章主要为大家详细介绍了Python实现对阿里云OSS对象存储的操作相关知识,包括连接,上传,下载,列举等功能,感兴趣的小伙伴可以了解下... 目录一、直接使用代码二、详细使用1. 环境准备2. 初始化配置3. bucket配置创建4. 文件上传到os

关于集合与数组转换实现方法

《关于集合与数组转换实现方法》:本文主要介绍关于集合与数组转换实现方法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、Arrays.asList()1.1、方法作用1.2、内部实现1.3、修改元素的影响1.4、注意事项2、list.toArray()2.1、方

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互

java实现docker镜像上传到harbor仓库的方式

《java实现docker镜像上传到harbor仓库的方式》:本文主要介绍java实现docker镜像上传到harbor仓库的方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 前 言2. 编写工具类2.1 引入依赖包2.2 使用当前服务器的docker环境推送镜像2.2