Mamba-minimal Mamba的最小限度实现 (一)

2024-03-09 06:20

本文主要是介绍Mamba-minimal Mamba的最小限度实现 (一),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 参数和数据尺寸约定
    • class MambaBlock
      • def forward
      • def __ int__
      • def ssm
      • def selective_scan

johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)

manba的简单最小限度实现,和原始论文实现state-spaces/mamba (github.com)](https://github.com/state-spaces/mamba/tree/main)相比,为了可读性对参数没有很好的初始化,原论文用CUDA写了并行扫描,所以速度会快。
这里介绍Mamba Block的实现

参数和数据尺寸约定

之后的数据尺寸以(b, l, d_in) 或者(b, l, d_model, d_state)简单表示

参数及简写Mamba论文简写
batch_size bB
序列长度 lL
隐藏维度 d / d_model
潜在状态维度 n / d_stateN
扩展因子 expandE
d_in / d_innerD
数据依赖步长 Δ \Delta Δ / delta
delta秩 dt_rank

class MambaBlock

def forward

根据forward简单梳理MambaBlock的结构

在这里插入图片描述

中间变量来源shape
输入x(b, l, d_model)
x_and_resx经过输入映射后(b, l, 2* d_in)
x切分后作为ssm分支输入(b, l, d_in)
res切分后作为门控分支输入(b, l, d_in)
y经过卷积,激活,ssm,门控后的输出(b, l, d_in)
outputy经过输出映射后得到(b, l, d_model)
def forward(self, x):(b, l, d) = x.shapex_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)x = rearrange(x, 'b l d_in -> b d_in l')x = self.conv1d(x)[:, :, :l]x = rearrange(x, 'b d_in l -> b l d_in')x = F.silu(x)y = self.ssm(x)y = y * F.silu(res)output = self.out_proj(y)return output

def __ int__

初始化主要初始了几个部分

组件定义

操作及简写维度变换
输入映射 in_proj(b, l, d_model) -> (b, l, 2*d_in)
序列变换 conv1d只取前l (b, d_in, l) -> (b, d_in, l)
非线性激活 silu
输出映射 out_proj(b, l, d_in) -> (b, l, d)

在这里插入图片描述
ssm初始化

操作及简写作用
参数生成映射 x_proj生成数据依赖的参数B, C, Δ \Delta Δ
delta映射 dt_proj Δ \Delta Δ从dt_rank映射到d_in
矩阵A初始化简单初始化
矩阵D初始化简单初始化
def __init__(self, args: ModelArgs):super().__init__()self.args = argsself.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)self.conv1d = nn.Conv1d(in_channels=args.d_inner,out_channels=args.d_inner,bias=args.conv_bias,kernel_size=args.d_conv,groups=args.d_inner,padding=args.d_conv - 1,)# ssm模型的初始化部分# x_proj takes in `x` and outputs the input-specific Δ, B, Cself.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)# dt_proj projects Δ from dt_rank to d_inself.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)self.A_log = nn.Parameter(torch.log(A))self.D = nn.Parameter(torch.ones(args.d_inner))self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

def ssm

这是我们数据处理流水线的搭建,这一部分是ssm模型参数定义,是ssm模型中相对于数据“不变”的部分。

SSM参数shape来源
状态矩阵A(d_in, n)在初始化中定义,非数据依赖
输入矩阵B(b, l, n)由x_db1切分而来,因此数据依赖
输出矩阵C(b, l, n)由x_db1切分而来,因此数据依赖
直接传递矩阵D(d_in)在初始化中定义,非数据依赖
数据依赖步长 Δ \Delta Δ(b, l, d_in)由x_db1切分而来,因此数据依赖

其中一部分变量初始化于class MambaBlock的初始化部分

中间变量及简写来源
数据生成变量 x_db1x经过参数映射x_proj生成
最终delta Δ \Delta Δ切分而来的 Δ \Delta Δ经过映射和softplus
 def ssm(self, x):(d_in, n) = self.A_log.shapeA = -torch.exp(self.A_log.float())  # shape (d_in, n)D = self.D.float()x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]return y
SSM参数shape
状态矩阵A(d_in, n)
输入矩阵B(b, l, n)
输出矩阵C(b, l, n)
直接传递矩阵D(d_in)

def selective_scan

我们的数据流水线搭建好以后,接下来就要让它动起来,这一部分是数据处理的动态或者动力。

在这里插入图片描述

在这里, A A A使用ZOH零阶保持离散化, B B B则简化为欧拉离散化

前向欧拉离散化
x k = ( I + Δ k A ) x k − 1 + Δ k B ⋅ u k x ( t + Δ ) = ( I + Δ A ) x ( t ) + Δ B ⋅ u ( t ) \begin{aligned} x_{k}& \begin{aligned}=(\boldsymbol{I}+\Delta_{k}\boldsymbol{A})x_{k-1}+\Delta_{k}\boldsymbol{B}\cdot u_{k}\end{aligned} \\ x(t+\Delta)& =(\boldsymbol{I}+\Delta\boldsymbol{A})x(t)+\Delta\boldsymbol{B}\cdot u(t) \end{aligned} xkx(t+Δ)=(I+ΔkA)xk1+ΔkBuk=(I+ΔA)x(t)+ΔBu(t)

零阶保持离散化
x k = e Δ k A x k − 1 + ( Δ k A ) − 1 ( e Δ k A − I ) ⋅ Δ k B ⋅ u k x ( t + Δ ) = e Δ A x ( t ) + ( Δ A ) − 1 ( e Δ A − I ) ⋅ Δ B ⋅ u ( t ) \begin{aligned} x_{k}& =e^{\Delta_{k}\boldsymbol A}x_{k-1}+(\Delta_{k}\boldsymbol A)^{-1}(e^{\Delta_{k}\boldsymbol A}-\boldsymbol{I})\cdot\Delta_{k}\boldsymbol B\cdot u_{k} \\ x(t+\Delta)& =e^{\Delta \boldsymbol A}x(t)+(\Delta \boldsymbol A)^{-1}(e^{\Delta \boldsymbol A}-\boldsymbol{I})\cdot\Delta \boldsymbol B\cdot u(t) \end{aligned} xkx(t+Δ)=eΔkAxk1+(ΔkA)1(eΔkAI)ΔkBuk=eΔAx(t)+(ΔA)1(eΔAI)ΔBu(t)

这里selective_scan是顺序形式,因此与原论文CUDA编写的并行感知算法相比要慢

def selective_scan(self, u, delta, A, B, C, D):(b, l, d_in) = u.shapen = A.shape[1]deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')# Perform selective scan (see scan_SSM() in The Annotated S4 [2])x = torch.zeros((b, d_in, n), device=deltaA.device)ys = []    for i in range(l):x = deltaA[:, i] * x + deltaB_u[:, i]y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')ys.append(y)y = torch.stack(ys, dim=1)  # shape (b, l, d_in)y = y + u * Dreturn y

这篇关于Mamba-minimal Mamba的最小限度实现 (一)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/789776

相关文章

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略 1. 特权模式限制2. 宿主机资源隔离3. 用户和组管理4. 权限提升控制5. SELinux配置 💖The Begin💖点点关注,收藏不迷路💖 Kubernetes的PodSecurityPolicy(PSP)是一个关键的安全特性,它在Pod创建之前实施安全策略,确保P

工厂ERP管理系统实现源码(JAVA)

工厂进销存管理系统是一个集采购管理、仓库管理、生产管理和销售管理于一体的综合解决方案。该系统旨在帮助企业优化流程、提高效率、降低成本,并实时掌握各环节的运营状况。 在采购管理方面,系统能够处理采购订单、供应商管理和采购入库等流程,确保采购过程的透明和高效。仓库管理方面,实现库存的精准管理,包括入库、出库、盘点等操作,确保库存数据的准确性和实时性。 生产管理模块则涵盖了生产计划制定、物料需求计划、

C++——stack、queue的实现及deque的介绍

目录 1.stack与queue的实现 1.1stack的实现  1.2 queue的实现 2.重温vector、list、stack、queue的介绍 2.1 STL标准库中stack和queue的底层结构  3.deque的简单介绍 3.1为什么选择deque作为stack和queue的底层默认容器  3.2 STL中对stack与queue的模拟实现 ①stack模拟实现

基于51单片机的自动转向修复系统的设计与实现

文章目录 前言资料获取设计介绍功能介绍设计清单具体实现截图参考文献设计获取 前言 💗博主介绍:✌全网粉丝10W+,CSDN特邀作者、博客专家、CSDN新星计划导师,一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP430/AVR等单片机设计 主要对象是咱们电子相关专业的大学生,希望您们都共创辉煌!✌💗 👇🏻 精彩专栏 推荐订阅👇🏻 单片机