Mamba-minimal Mamba的最小限度实现 (一)

2024-03-09 06:20

本文主要是介绍Mamba-minimal Mamba的最小限度实现 (一),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 参数和数据尺寸约定
    • class MambaBlock
      • def forward
      • def __ int__
      • def ssm
      • def selective_scan

johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)

manba的简单最小限度实现,和原始论文实现state-spaces/mamba (github.com)](https://github.com/state-spaces/mamba/tree/main)相比,为了可读性对参数没有很好的初始化,原论文用CUDA写了并行扫描,所以速度会快。
这里介绍Mamba Block的实现

参数和数据尺寸约定

之后的数据尺寸以(b, l, d_in) 或者(b, l, d_model, d_state)简单表示

参数及简写Mamba论文简写
batch_size bB
序列长度 lL
隐藏维度 d / d_model
潜在状态维度 n / d_stateN
扩展因子 expandE
d_in / d_innerD
数据依赖步长 Δ \Delta Δ / delta
delta秩 dt_rank

class MambaBlock

def forward

根据forward简单梳理MambaBlock的结构

在这里插入图片描述

中间变量来源shape
输入x(b, l, d_model)
x_and_resx经过输入映射后(b, l, 2* d_in)
x切分后作为ssm分支输入(b, l, d_in)
res切分后作为门控分支输入(b, l, d_in)
y经过卷积,激活,ssm,门控后的输出(b, l, d_in)
outputy经过输出映射后得到(b, l, d_model)
def forward(self, x):(b, l, d) = x.shapex_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)x = rearrange(x, 'b l d_in -> b d_in l')x = self.conv1d(x)[:, :, :l]x = rearrange(x, 'b d_in l -> b l d_in')x = F.silu(x)y = self.ssm(x)y = y * F.silu(res)output = self.out_proj(y)return output

def __ int__

初始化主要初始了几个部分

组件定义

操作及简写维度变换
输入映射 in_proj(b, l, d_model) -> (b, l, 2*d_in)
序列变换 conv1d只取前l (b, d_in, l) -> (b, d_in, l)
非线性激活 silu
输出映射 out_proj(b, l, d_in) -> (b, l, d)

在这里插入图片描述
ssm初始化

操作及简写作用
参数生成映射 x_proj生成数据依赖的参数B, C, Δ \Delta Δ
delta映射 dt_proj Δ \Delta Δ从dt_rank映射到d_in
矩阵A初始化简单初始化
矩阵D初始化简单初始化
def __init__(self, args: ModelArgs):super().__init__()self.args = argsself.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)self.conv1d = nn.Conv1d(in_channels=args.d_inner,out_channels=args.d_inner,bias=args.conv_bias,kernel_size=args.d_conv,groups=args.d_inner,padding=args.d_conv - 1,)# ssm模型的初始化部分# x_proj takes in `x` and outputs the input-specific Δ, B, Cself.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)# dt_proj projects Δ from dt_rank to d_inself.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)self.A_log = nn.Parameter(torch.log(A))self.D = nn.Parameter(torch.ones(args.d_inner))self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

def ssm

这是我们数据处理流水线的搭建,这一部分是ssm模型参数定义,是ssm模型中相对于数据“不变”的部分。

SSM参数shape来源
状态矩阵A(d_in, n)在初始化中定义,非数据依赖
输入矩阵B(b, l, n)由x_db1切分而来,因此数据依赖
输出矩阵C(b, l, n)由x_db1切分而来,因此数据依赖
直接传递矩阵D(d_in)在初始化中定义,非数据依赖
数据依赖步长 Δ \Delta Δ(b, l, d_in)由x_db1切分而来,因此数据依赖

其中一部分变量初始化于class MambaBlock的初始化部分

中间变量及简写来源
数据生成变量 x_db1x经过参数映射x_proj生成
最终delta Δ \Delta Δ切分而来的 Δ \Delta Δ经过映射和softplus
 def ssm(self, x):(d_in, n) = self.A_log.shapeA = -torch.exp(self.A_log.float())  # shape (d_in, n)D = self.D.float()x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]return y
SSM参数shape
状态矩阵A(d_in, n)
输入矩阵B(b, l, n)
输出矩阵C(b, l, n)
直接传递矩阵D(d_in)

def selective_scan

我们的数据流水线搭建好以后,接下来就要让它动起来,这一部分是数据处理的动态或者动力。

在这里插入图片描述

在这里, A A A使用ZOH零阶保持离散化, B B B则简化为欧拉离散化

前向欧拉离散化
x k = ( I + Δ k A ) x k − 1 + Δ k B ⋅ u k x ( t + Δ ) = ( I + Δ A ) x ( t ) + Δ B ⋅ u ( t ) \begin{aligned} x_{k}& \begin{aligned}=(\boldsymbol{I}+\Delta_{k}\boldsymbol{A})x_{k-1}+\Delta_{k}\boldsymbol{B}\cdot u_{k}\end{aligned} \\ x(t+\Delta)& =(\boldsymbol{I}+\Delta\boldsymbol{A})x(t)+\Delta\boldsymbol{B}\cdot u(t) \end{aligned} xkx(t+Δ)=(I+ΔkA)xk1+ΔkBuk=(I+ΔA)x(t)+ΔBu(t)

零阶保持离散化
x k = e Δ k A x k − 1 + ( Δ k A ) − 1 ( e Δ k A − I ) ⋅ Δ k B ⋅ u k x ( t + Δ ) = e Δ A x ( t ) + ( Δ A ) − 1 ( e Δ A − I ) ⋅ Δ B ⋅ u ( t ) \begin{aligned} x_{k}& =e^{\Delta_{k}\boldsymbol A}x_{k-1}+(\Delta_{k}\boldsymbol A)^{-1}(e^{\Delta_{k}\boldsymbol A}-\boldsymbol{I})\cdot\Delta_{k}\boldsymbol B\cdot u_{k} \\ x(t+\Delta)& =e^{\Delta \boldsymbol A}x(t)+(\Delta \boldsymbol A)^{-1}(e^{\Delta \boldsymbol A}-\boldsymbol{I})\cdot\Delta \boldsymbol B\cdot u(t) \end{aligned} xkx(t+Δ)=eΔkAxk1+(ΔkA)1(eΔkAI)ΔkBuk=eΔAx(t)+(ΔA)1(eΔAI)ΔBu(t)

这里selective_scan是顺序形式,因此与原论文CUDA编写的并行感知算法相比要慢

def selective_scan(self, u, delta, A, B, C, D):(b, l, d_in) = u.shapen = A.shape[1]deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')# Perform selective scan (see scan_SSM() in The Annotated S4 [2])x = torch.zeros((b, d_in, n), device=deltaA.device)ys = []    for i in range(l):x = deltaA[:, i] * x + deltaB_u[:, i]y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')ys.append(y)y = torch.stack(ys, dim=1)  # shape (b, l, d_in)y = y + u * Dreturn y

这篇关于Mamba-minimal Mamba的最小限度实现 (一)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/789776

相关文章

C++使用栈实现括号匹配的代码详解

《C++使用栈实现括号匹配的代码详解》在编程中,括号匹配是一个常见问题,尤其是在处理数学表达式、编译器解析等任务时,栈是一种非常适合处理此类问题的数据结构,能够精确地管理括号的匹配问题,本文将通过C+... 目录引言问题描述代码讲解代码解析栈的状态表示测试总结引言在编程中,括号匹配是一个常见问题,尤其是在

Java实现检查多个时间段是否有重合

《Java实现检查多个时间段是否有重合》这篇文章主要为大家详细介绍了如何使用Java实现检查多个时间段是否有重合,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录流程概述步骤详解China编程步骤1:定义时间段类步骤2:添加时间段步骤3:检查时间段是否有重合步骤4:输出结果示例代码结语作

使用C++实现链表元素的反转

《使用C++实现链表元素的反转》反转链表是链表操作中一个经典的问题,也是面试中常见的考题,本文将从思路到实现一步步地讲解如何实现链表的反转,帮助初学者理解这一操作,我们将使用C++代码演示具体实现,同... 目录问题定义思路分析代码实现带头节点的链表代码讲解其他实现方式时间和空间复杂度分析总结问题定义给定

Java覆盖第三方jar包中的某一个类的实现方法

《Java覆盖第三方jar包中的某一个类的实现方法》在我们日常的开发中,经常需要使用第三方的jar包,有时候我们会发现第三方的jar包中的某一个类有问题,或者我们需要定制化修改其中的逻辑,那么应该如何... 目录一、需求描述二、示例描述三、操作步骤四、验证结果五、实现原理一、需求描述需求描述如下:需要在

如何使用Java实现请求deepseek

《如何使用Java实现请求deepseek》这篇文章主要为大家详细介绍了如何使用Java实现请求deepseek功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1.deepseek的api创建2.Java实现请求deepseek2.1 pom文件2.2 json转化文件2.2

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

如何通过Python实现一个消息队列

《如何通过Python实现一个消息队列》这篇文章主要为大家详细介绍了如何通过Python实现一个简单的消息队列,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录如何通过 python 实现消息队列如何把 http 请求放在队列中执行1. 使用 queue.Queue 和 reque

Python如何实现PDF隐私信息检测

《Python如何实现PDF隐私信息检测》随着越来越多的个人信息以电子形式存储和传输,确保这些信息的安全至关重要,本文将介绍如何使用Python检测PDF文件中的隐私信息,需要的可以参考下... 目录项目背景技术栈代码解析功能说明运行结php果在当今,数据隐私保护变得尤为重要。随着越来越多的个人信息以电子形

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

使用Python快速实现链接转word文档

《使用Python快速实现链接转word文档》这篇文章主要为大家详细介绍了如何使用Python快速实现链接转word文档功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 演示代码展示from newspaper import Articlefrom docx import