Vision Mamba (ViM) Explained Simply: An Illustrated Guide

2024-04-30 06:44


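Contents

    • Introduction: Mamba
    • Chapter 1: Environment Setup
      • 1.1 Installation Guide
      • 1.2 Troubleshooting
      • 1.3 Installation Summary
    • Chapter 2: Plug-and-Play Modules
      • 2.1 Module 1: Mamba Vision
        • Code: models_mamba.py
        • Run Results
      • 2.2 Module 2: MambaIR
        • Code: MambaIR
        • Run Results
    • Chapter 3: Reading and Tracking Key Papers
      • Key Papers
      • Tracking Mamba Papers
    • Chapter 4: Mamba Theory and Analysis
    • Chapter 5: Summary and Outlook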


Introduction: Mamba

2024-04-29 16:06:08: starting today, I am documenting how I learn and use the Mamba module.


Chapter 1: Environment Setup

Tested myself: following the installation steps below, the setup works.

Code used: Vision Mamba, https://github.com/hustvl/Vim

git clone https://github.com/hustvl/Vim.git

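1.1 Installation Guide

After downloading Vision Mamba, follow the tutorial below step by step and the installation should succeed.

Vision Mamba training run log: fixing the bimamba_type error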

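1.2 Troubleshooting

If you run into problems, refer to the post below, which gives a fairly complete summary.

Mamba environment installation: common pitfalls and their fixes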

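1.3 Installation Summary

The key step is installing causal_conv1d and mamba_ssm. It is best to download the offline .whl files and then install them with pip. One point to watch: the mamba_ssm installed in the conda environment must be replaced with the mamba_ssm shipped in the official Vim project, because Vim modifies it.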
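To check which of these packages Python can actually see, and whether the mamba_ssm being imported is the Vim copy, a small script like the one below can help. It is a sketch that assumes the bimamba_type error mentioned in 1.1 comes from importing the stock mamba_ssm, whose Mamba class has no `bimamba_type` argument; `module_available` and `supports_bimamba` are hypothetical helper names, not part of either project.

```python
import importlib.util
import inspect


def module_available(name: str) -> bool:
    """True if a top-level package can be located (without importing it)."""
    return importlib.util.find_spec(name) is not None


def supports_bimamba(mamba_cls) -> bool:
    """True if the class's __init__ accepts the Vim-specific `bimamba_type` argument."""
    return "bimamba_type" in inspect.signature(mamba_cls.__init__).parameters


if __name__ == "__main__":
    for pkg in ("torch", "causal_conv1d", "mamba_ssm"):
        print(f"{pkg:<14} found: {module_available(pkg)}")
    if module_available("mamba_ssm"):
        try:
            from mamba_ssm.modules.mamba_simple import Mamba
        except ImportError as err:  # e.g. the compiled CUDA extension failed to load
            print("mamba_ssm found but failed to import:", err)
        else:
            # Assumption: upstream Mamba lacks `bimamba_type`; Vim's bundled copy adds it.
            print("Vim's mamba_ssm in use:", supports_bimamba(Mamba))
```

If the last line prints False, the conda environment is most likely still using the upstream package rather than the copy from the Vim repository.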


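Chapter 2: Plug-and-Play Modules

2.1 Module 1: Mamba Vision

GitHub: https://github.com/hustvl/Vim
After downloading the code and configuring the environment, replace Vim/vim/models_mamba.py with the code below; it can then be run directly (run it from inside Vim/vim, since the file imports rope from that folder).

Run command

python models_mamba.py
Code: models_mamba.py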
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from torch import Tensor
from typing import Optional

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, lecun_normal_

from timm.models.layers import DropPath, to_2tuple
from timm.models.vision_transformer import _load_weights

import math

from collections import namedtuple

from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

from rope import *
import random

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


__all__ = [
    'vim_tiny_patch16_224', 'vim_small_patch16_224', 'vim_base_patch16_224',
    'vim_tiny_patch16_384', 'vim_small_patch16_384', 'vim_base_patch16_384',
]


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = ((img_size[0] - patch_size[0]) // stride + 1, (img_size[1] - patch_size[1]) // stride + 1)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x


class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, drop_path=0.,
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)

            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            if residual is None:
                hidden_states, residual = fused_add_norm_fn(
                    hidden_states,
                    self.norm.weight,
                    self.norm.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm.eps,
                )
            else:
                hidden_states, residual = fused_add_norm_fn(
                    self.drop_path(hidden_states),
                    self.norm.weight,
                    self.norm.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm.eps,
                )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)


def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    drop_path=0.,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
    if_bimamba=False,
    bimamba_type="none",
    if_devide_out=False,
    init_layer_scale=None,
):
    if if_bimamba:
        bimamba_type = "v1"
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = partial(Mamba, layer_idx=layer_idx, bimamba_type=bimamba_type, if_devide_out=if_devide_out,
                        init_layer_scale=init_layer_scale, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        drop_path=drop_path,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


def segm_init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


class VisionMamba(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 stride=16,
                 depth=24,
                 embed_dim=192,
                 channels=3,
                 num_classes=1000,
                 ssm_cfg=None,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_epsilon: float = 1e-5,
                 rms_norm: bool = False,
                 initializer_cfg=None,
                 fused_add_norm=False,
                 residual_in_fp32=False,
                 device=None,
                 dtype=None,
                 ft_seq_len=None,
                 pt_hw_seq_len=14,
                 if_bidirectional=False,
                 final_pool_type='none',
                 if_abs_pos_embed=False,
                 if_rope=False,
                 if_rope_residual=False,
                 flip_img_sequences_ratio=-1.,
                 if_bimamba=False,
                 bimamba_type="none",
                 if_cls_token=False,
                 if_devide_out=False,
                 init_layer_scale=None,
                 use_double_cls_token=False,
                 use_middle_cls_token=False,
                 **kwargs):
        factory_kwargs = {"device": device, "dtype": dtype}
        # add factory_kwargs into kwargs
        kwargs.update(factory_kwargs)
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.if_bidirectional = if_bidirectional
        self.final_pool_type = final_pool_type
        self.if_abs_pos_embed = if_abs_pos_embed
        self.if_rope = if_rope
        self.if_rope_residual = if_rope_residual
        self.flip_img_sequences_ratio = flip_img_sequences_ratio
        self.if_cls_token = if_cls_token
        self.use_double_cls_token = use_double_cls_token
        self.use_middle_cls_token = use_middle_cls_token
        self.num_tokens = 1 if if_cls_token else 0

        # pretrain parameters
        self.num_classes = num_classes
        self.d_model = self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, stride=stride, in_chans=channels, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        if if_cls_token:
            if use_double_cls_token:
                self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                self.num_tokens = 2
            else:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                # self.num_tokens = 1

        if if_abs_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim))
            self.pos_drop = nn.Dropout(p=drop_rate)

        if if_rope:
            half_head_dim = embed_dim // 2
            hw_seq_len = img_size // patch_size
            self.rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=pt_hw_seq_len,
                ft_seq_len=hw_seq_len
            )
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # TODO: release this comment
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        # import ipdb;ipdb.set_trace()
        inter_dpr = [0.0] + dpr
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        # transformer blocks
        self.layers = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    if_bimamba=if_bimamba,
                    bimamba_type=bimamba_type,
                    drop_path=inter_dpr[i],
                    if_devide_out=if_devide_out,
                    init_layer_scale=init_layer_scale,
                    **factory_kwargs,
                )
                for i in range(depth)
            ]
        )

        # output head
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            embed_dim, eps=norm_epsilon, **factory_kwargs
        )

        # self.pre_logits = nn.Identity()

        # original init
        self.patch_embed.apply(segm_init_weights)
        self.head.apply(segm_init_weights)
        if if_abs_pos_embed:
            trunc_normal_(self.pos_embed, std=.02)
        if if_cls_token:
            if use_double_cls_token:
                trunc_normal_(self.cls_token_head, std=.02)
                trunc_normal_(self.cls_token_tail, std=.02)
            else:
                trunc_normal_(self.cls_token, std=.02)

        # mamba init
        self.apply(
            partial(
                _init_weights,
                n_layer=depth,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token", "dist_token", "cls_token_head", "cls_token_tail"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        _load_weights(self, checkpoint_path, prefix)

    def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        x = self.patch_embed(x)
        B, M, _ = x.shape

        if self.if_cls_token:
            if self.use_double_cls_token:
                cls_token_head = self.cls_token_head.expand(B, -1, -1)
                cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
                token_position = [0, M + 1]
                x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
                M = x.shape[1]
            else:
                if self.use_middle_cls_token:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = M // 2
                    # add cls token in the middle
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                elif if_random_cls_token_position:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = random.randint(0, M)
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                    print("token_position: ", token_position)
                else:
                    cls_token = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
                    token_position = 0
                    x = torch.cat((cls_token, x), dim=1)
                M = x.shape[1]

        if self.if_abs_pos_embed:
            # if new_grid_size[0] == self.patch_embed.grid_size[0] and new_grid_size[1] == self.patch_embed.grid_size[1]:
            #     x = x + self.pos_embed
            # else:
            #     pos_embed = interpolate_pos_embed_online(
            #                 self.pos_embed, self.patch_embed.grid_size, new_grid_size,0
            #             )
            x = x + self.pos_embed
            x = self.pos_drop(x)

        if if_random_token_rank:
            # generate random shuffle indices
            shuffle_indices = torch.randperm(M)

            if isinstance(token_position, list):
                print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
            else:
                print("original value: ", x[0, token_position, 0])
            print("original token_position: ", token_position)

            # apply the shuffle
            x = x[:, shuffle_indices, :]

            if isinstance(token_position, list):
                # find the new positions of the cls tokens after shuffling
                new_token_position = [torch.where(shuffle_indices == token_position[i])[0].item() for i in range(len(token_position))]
                token_position = new_token_position
            else:
                # find the new position of the cls token after shuffling
                token_position = torch.where(shuffle_indices == token_position)[0].item()

            if isinstance(token_position, list):
                print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
            else:
                print("new value: ", x[0, token_position, 0])
            print("new token_position: ", token_position)

        if_flip_img_sequences = False
        if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
            x = x.flip([1])
            if_flip_img_sequences = True

        # mamba impl
        residual = None
        hidden_states = x
        if not self.if_bidirectional:
            for layer in self.layers:

                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])

                # rope about
                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)

                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])

                hidden_states, residual = layer(
                    hidden_states, residual, inference_params=inference_params
                )
        else:
            # get two layers in a single for-loop
            for i in range(len(self.layers) // 2):
                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)

                hidden_states_f, residual_f = self.layers[i * 2](
                    hidden_states, residual, inference_params=inference_params
                )
                hidden_states_b, residual_b = self.layers[i * 2 + 1](
                    hidden_states.flip([1]), None if residual is None else residual.flip([1]),
                    inference_params=inference_params
                )
                hidden_states = hidden_states_f + hidden_states_b.flip([1])
                residual = residual_f + residual_b.flip([1])

        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        # return only cls token if it exists
        if self.if_cls_token:
            if self.use_double_cls_token:
                return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
            else:
                if self.use_middle_cls_token:
                    return hidden_states[:, token_position, :]
                elif if_random_cls_token_position:
                    return hidden_states[:, token_position, :]
                else:
                    return hidden_states[:, token_position, :]

        if self.final_pool_type == 'none':
            return hidden_states[:, -1, :]
        elif self.final_pool_type == 'mean':
            return hidden_states.mean(dim=1)
        elif self.final_pool_type == 'max':
            return hidden_states
        elif self.final_pool_type == 'all':
            return hidden_states
        else:
            raise NotImplementedError

    def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position,
                                  if_random_token_rank=if_random_token_rank)
        if return_features:
            return x
        x = self.head(x)
        if self.final_pool_type == 'max':
            x = x.max(dim=1)[0]
        return x


@register_model
def vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    model = VisionMamba(
        patch_size=16, embed_dim=192, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True,
        final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2",
        if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="to.do",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def vim_tiny_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    model = VisionMamba(
        patch_size=16, stride=8, embed_dim=192, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True,
        final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2",
        if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="to.do",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    model = VisionMamba(
        patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True,
        final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2",
        if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="to.do",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    model = VisionMamba(
        patch_size=16, stride=8, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True,
        final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2",
        if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="to.do",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


if __name__ == '__main__':
    # cuda or cpu (in practice the mamba_ssm / causal_conv1d / Triton kernels require a CUDA GPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Instantiate the full model and get the classification output
    inputs = torch.randn(1, 3, 224, 224).to(device)
    model = vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False).to(device)
    print(model)
    outputs = model(inputs)
    print(outputs.shape)

    # Instantiate a single Mamba block; input and output feature shapes are the same (B, C, H, W)
    x = torch.rand(10, 16, 64, 128).to(device)
    B, C, H, W = x.shape
    print("Input feature shape:", x.shape)
    x = x.view(B, C, H * W).permute(0, 2, 1)  # (B, C, H, W) -> (B, H*W, C)
    print("After reshaping:", x.shape)
    mamba = create_block(d_model=C).to(device)
    # The block returns a tuple: (hidden_states, residual)
    hidden_states, residual = mamba(x)
    x = hidden_states.permute(0, 2, 1).reshape(B, C, H, W)  # (B, H*W, C) -> (B, C, H, W)
    print("Output feature shape:", x.shape)
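The long model names encode how many tokens the Mamba layers have to scan. PatchEmbed computes (img_size - patch_size) // stride + 1 patches per side, and with use_middle_cls_token=True the class token is inserted at index M // 2. The plain-Python helper below (hypothetical names, not part of the Vim repo) reproduces that arithmetic; it shows why the stride8 variants are much more expensive: 729 patches instead of 196, roughly 3.7 times as many tokens.

```python
def grid_size(img_size: int, patch_size: int, stride: int) -> int:
    """Patches along one side, same formula as PatchEmbed.grid_size (no padding)."""
    return (img_size - patch_size) // stride + 1


def sequence_layout(img_size: int = 224, patch_size: int = 16, stride: int = 16,
                    use_middle_cls_token: bool = True):
    """Return (num_patches, sequence_length, cls_token_index) for a single cls token."""
    g = grid_size(img_size, patch_size, stride)
    num_patches = g * g
    # forward_features inserts the cls token at M // 2 (middle) or at index 0
    cls_index = num_patches // 2 if use_middle_cls_token else 0
    return num_patches, num_patches + 1, cls_index


if __name__ == "__main__":
    for stride in (16, 8):
        print(f"stride={stride}:", sequence_layout(stride=stride))
    # stride=16: (196, 197, 98)
    # stride=8: (729, 730, 364)
```

Note also that in forward_features above, when if_cls_token=True the method returns the class-token feature before final_pool_type is consulted, so in these configurations the "final_pool_mean" part of the model name does not change the output.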
Run Results

[Figure: screenshot of the run output]
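The __main__ block of models_mamba.py shows the plug-and-play pattern: flatten a (B, C, H, W) map to (B, H*W, C), run the block, and reshape back. One detail worth knowing: when called with residual=None, Block.forward returns hidden_states = Mixer(LN(x)) and residual = x separately, so the demo's output does not include the skip connection. The wrapper below is a hypothetical sketch (SequenceMixer2d is not part of the Vim repo) that makes adding the residual an explicit option; a stand-in nn.Linear replaces create_block(d_model=C) so that it also runs on a machine without mamba_ssm.

```python
import torch
import torch.nn as nn


class SequenceMixer2d(nn.Module):
    """Hypothetical wrapper: apply a (B, L, C) sequence block to a (B, C, H, W) feature map.

    `block` may return a tensor, or a (hidden_states, residual) tuple like the
    Vim Block built by create_block(d_model=C).
    """

    def __init__(self, block: nn.Module, add_residual: bool = False):
        super().__init__()
        self.block = block
        self.add_residual = add_residual

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        seq = x.flatten(2).transpose(1, 2)             # (B, C, H, W) -> (B, H*W, C)
        out = self.block(seq)
        if isinstance(out, tuple):                      # Vim Block: (hidden_states, residual)
            hidden_states, residual = out
            out = hidden_states + residual if self.add_residual else hidden_states
        return out.transpose(1, 2).reshape(B, C, H, W)  # (B, H*W, C) -> (B, C, H, W)


if __name__ == "__main__":
    x = torch.rand(2, 16, 8, 12)
    # Stand-in for create_block(d_model=16); swap in the real block on a CUDA machine.
    layer = SequenceMixer2d(nn.Linear(16, 16))
    print(layer(x).shape)  # torch.Size([2, 16, 8, 12])
```

Whether to add the residual back is a design choice: inside VisionMamba the next Block (or the final norm_f) performs that addition, but a standalone block dropped into another network usually needs it.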


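2.2 Module 2: MambaIR

Bilibili uploader: @箫张跋扈

Video: Mamba Back! A plug-and-play module from the Mamba family (TimeMachine) for time-series tasks!

After downloading the code, put the code below into a file named MambaIR.py and run it to get the result.

Code: MambaIR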
# Code Implementation of the MambaIR Model
import warnings
warnings.filterwarnings("ignore")
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Optional, Callable
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
from einops import rearrange, repeat

"""
Recently, selective structured state space models, in particular the improved variant Mamba, have shown great
potential for modeling long-range dependencies with linear complexity.
However, standard Mamba still faces challenges in low-level vision, such as local pixel forgetting and channel
redundancy. In this work, we introduce local enhancement and channel attention to improve the vanilla Mamba.
In this way, we exploit local pixel similarity and reduce channel redundancy. Extensive experiments demonstrate
the superiority of our method.
"""

NEG_INF = -1000000


class ChannelAttention(nn.Module):
    """Channel attention used in RCAN.
    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
    """

    def __init__(self, num_feat, squeeze_factor=16):
        super(ChannelAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
            nn.Sigmoid())

    def forward(self, x):
        y = self.attention(x)
        return x * y


class CAB(nn.Module):
    def __init__(self, num_feat, is_light_sr=False, compress_ratio=3, squeeze_factor=30):
        super(CAB, self).__init__()
        if is_light_sr:  # use a depth-wise conv for lightweight SR, for better efficiency
            self.cab = nn.Sequential(
                nn.Conv2d(num_feat, num_feat, 3, 1, 1, groups=num_feat),
                ChannelAttention(num_feat, squeeze_factor)
            )
        else:  # for classic SR
            self.cab = nn.Sequential(
                nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
                nn.GELU(),
                nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
                ChannelAttention(num_feat, squeeze_factor)
            )

    def forward(self, x):
        return self.cab(x)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DynamicPosBias(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.pos_dim = dim // 4
        self.pos_proj = nn.Linear(2, self.pos_dim)
        self.pos1 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.pos_dim),
        )
        self.pos2 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.pos_dim)
        )
        self.pos3 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.num_heads)
        )

    def forward(self, biases):
        pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
        return pos

    def flops(self, N):
        flops = N * 2 * self.pos_dim
        flops += N * self.pos_dim * self.pos_dim
        flops += N * self.pos_dim * self.pos_dim
        flops += N * self.pos_dim * self.num_heads
        return flops


class Attention(nn.Module):
    r""" Multi-head self attention module with dynamic position bias.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 position_bias=True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.position_bias = position_bias
        if self.position_bias:
            self.pos = DynamicPosBias(self.dim // 4, self.num_heads)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, H, W, mask=None):
        """
        Args:
            x: input features with shape of (num_groups*B, N, C)
            mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None
            H: height of each group
            W: width of each group
        """
        group_size = (H, W)
        B_, N, C = x.shape
        assert H * W == N
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B_, self.num_heads, N, N), N = H*W

        if self.position_bias:
            # generate mother-set
            position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device)
            position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device)
            biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))  # 2, 2Gh-1, 2Gw-1
            biases = biases.flatten(1).transpose(0, 1).contiguous().float()  # (2Gh-1)*(2Gw-1), 2

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(group_size[0], device=attn.device)
            coords_w = torch.arange(group_size[1], device=attn.device)
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Gh, Gw
            coords_flatten = torch.flatten(coords, 1)  # 2, Gh*Gw
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Gh*Gw, Gh*Gw
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Gh*Gw, Gh*Gw, 2
            relative_coords[:, :, 0] += group_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += group_size[1] - 1
            relative_coords[:, :, 0] *= 2 * group_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Gh*Gw, Gh*Gw

            pos = self.pos(biases)  # (2Gh-1)*(2Gw-1), heads
            # select position bias
            relative_position_bias = pos[relative_position_index.view(-1)].view(
                group_size[0] * group_size[1], group_size[0] * group_size[1], -1)  # Gh*Gw, Gh*Gw, nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Gh*Gw, Gh*Gw
            attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nP = mask.shape[0]
            attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)  # (B, nP, nHead, N, N)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SS2D(nn.Module):
    def __init__(
            self,
            d_model,
            d_state=16,
            d_conv=3,
            expand=2.,
            dt_rank="auto",
            dt_min=0.001,
            dt_max=0.1,
            dt_init="random",
            dt_scale=1.0,
            dt_init_floor=1e-4,
            dropout=0.,
            conv_bias=True,
            bias=False,
            device=None,
            dtype=None,
            **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
        )
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, N, inner)
        del self.x_proj

        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, inner)
        del self.dt_projs

        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)  # (K*D, N) after merge
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K*D,) after merge

        self.selective_scan = selective_scan_fn

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward_core(self, x: torch.Tensor):
        B, C, H, W = x.shape
        L = H * W
        K = 4
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l), e.g. (1, 4, 192, 3136)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        xs = xs.float().view(B, -1, L)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)
        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, x: torch.Tensor, **kwargs):
        B, H, W, C = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))
        y1, y2, y3, y4 = self.forward_core(x)
        assert y1.dtype == torch.float32
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * F.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out


class VSSBlock(nn.Module):
    def __init__(
            self,
            hidden_dim: int = 0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
            attn_drop_rate: float = 0,
            d_state: int = 16,
            expand: float = 2.,
            is_light_sr: bool = False,
            **kwargs,
    ):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = SS2D(d_model=hidden_dim, d_state=d_state, expand=expand, dropout=attn_drop_rate, **kwargs)
        self.drop_path = DropPath(drop_path)
        self.skip_scale = nn.Parameter(torch.ones(hidden_dim))
        self.conv_blk = CAB(hidden_dim, is_light_sr)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        self.skip_scale2 = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, input, x_size):
        # input: [B, H*W, C]
        B, L, C = input.shape
        input = input.view(B, *x_size, C).contiguous()  # [B, H, W, C]
        x = self.ln_1(input)
        x = input * self.skip_scale + self.drop_path(self.self_attention(x))
        x = x * self.skip_scale2 + self.conv_blk(self.ln_2(x).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous()
        x = x.view(B, -1, C).contiguous()
        return x


if __name__ == '__main__':
    # Build a VSSBlock with hidden_dim=128
    block = VSSBlock(hidden_dim=128, drop_path=0.1, attn_drop_rate=0.1, d_state=16, expand=2.0, is_light_sr=False)

    # Move the module to a suitable device (selective_scan_fn needs the CUDA kernels, so use a GPU in practice)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    block = block.to(device)

    # Random input of shape [B, H*W, C]: batch size 4, 32x32 feature maps, 128 channels
    B, H, W, C = 4, 32, 32, 128
    input_tensor = torch.rand(B, H * W, C).to(device)

    # Forward pass
    output_tensor = block(input_tensor, (H, W))

    # Print the input and output tensor sizes
    print("Input tensor size:", input_tensor.size())
    print("Output tensor size:", output_tensor.size())
Run results

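The VSSBlock test above only checks tensor shapes. To see what the four-direction scan in SS2D.forward_core actually does to token order, here is a minimal sketch that pulls out just its reordering steps (the tensor operations are copied from forward_core; the function names cross_scan and cross_merge are hypothetical) and replaces the selective scan with an identity, so it runs on CPU without mamba_ssm. It assumes only PyTorch.

```python
import torch


def cross_scan(x: torch.Tensor) -> torch.Tensor:
    """Flatten a (B, C, H, W) map into 4 token sequences, as forward_core does.

    Returns (B, 4, C, L) with L = H * W:
      k=0 row-major, k=1 column-major, k=2 reversed row-major, k=3 reversed column-major.
    """
    B, C, H, W = x.shape
    L = H * W
    x_hwwh = torch.stack(
        [x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
        dim=1,
    )  # (B, 2, C, L)
    return torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (B, 4, C, L)


def cross_merge(out_y: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Undo the 4 orderings and sum them, mirroring the end of forward_core and forward."""
    B, K, C, L = out_y.shape
    inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
    wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
    invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
    return out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y  # (B, C, L), row-major order


if __name__ == "__main__":
    B, C, H, W = 1, 1, 2, 3
    x = torch.arange(H * W, dtype=torch.float32).view(B, C, H, W)
    xs = cross_scan(x)
    for k, name in enumerate(["row-major", "column-major", "reversed row", "reversed column"]):
        print(name, xs[0, k, 0].tolist())
    # With an identity "scan", every direction maps back to the same position,
    # so the merged result is exactly 4 * x in row-major order.
    y = cross_merge(xs, H, W)
    print(torch.equal(y, 4 * x.view(B, C, -1)))
```

For the 2×3 map the four sequences are 0 1 2 3 4 5, 0 3 1 4 2 5, 5 4 3 2 1 0 and 5 2 4 1 3 0, and the last line prints True. Because each direction is mapped back to row-major order before the outputs are combined, forward can simply add y1 + y2 + y3 + y4.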


Chapter 3: Reading and Tracking Classic Papers

Original Mamba paper: Mamba: Linear-Time Sequence Modeling with Selective State Spaces
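The core operation of that paper, which selective_scan_fn runs as a fused CUDA kernel inside forward_core above, is the selective scan: a linear recurrence h_t = exp(Δ_t A) h_{t-1} + Δ_t B_t u_t, y_t = C_t h_t + D u_t, where the step size Δ and the projections B and C depend on the input. Below is a slow, loop-based sketch written only to make that recurrence readable. It is not the mamba_ssm implementation: the function name is hypothetical, and the shape conventions (B and C shared by groups of channels, softplus applied to Δ plus its bias) are my assumptions based on how forward_core calls selective_scan_fn.

```python
import torch
import torch.nn.functional as F


def selective_scan_ref_sketch(u, delta, A, B, C, D, delta_bias=None, delta_softplus=True):
    """Naive sequential selective scan (illustrative only, O(L) Python loop).

    Assumed shapes, following the forward_core call:
      u, delta: (b, d, l)   A: (d, n)   B, C: (b, g, n, l) with d divisible by g
      D: (d,)               returns y: (b, d, l)
    """
    b, d, l = u.shape
    n = A.shape[1]
    g = B.shape[1]
    if delta_bias is not None:
        delta = delta + delta_bias[None, :, None]
    if delta_softplus:
        delta = F.softplus(delta)  # keeps the step size positive
    # Each group of d // g consecutive channels shares one B/C sequence
    B = B.repeat_interleave(d // g, dim=1)  # (b, d, n, l)
    C = C.repeat_interleave(d // g, dim=1)  # (b, d, n, l)
    h = u.new_zeros(b, d, n)  # hidden state, starts at zero
    ys = []
    for t in range(l):
        dt = delta[:, :, t, None]  # (b, d, 1)
        dA = torch.exp(dt * A)  # zero-order-hold discretisation of A
        dBu = dt * B[:, :, :, t] * u[:, :, t, None]  # simplified discretisation of B
        h = dA * h + dBu  # input-dependent state update
        ys.append((h * C[:, :, :, t]).sum(-1))  # y_t = C_t h_t
    y = torch.stack(ys, dim=-1)  # (b, d, l)
    return y + u * D[None, :, None]  # D acts as a skip connection
```

Because Δ, B and C change at every step, the recurrence cannot be turned into a fixed convolution as in S4; this is why Mamba relies on a hardware-aware parallel scan instead of a loop like the one above.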

Classic papers

  1. Vision Mamba@Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model
  2. MambaIR@MambaIR: A Simple Baseline for Image Restoration with State-Space Model
  3. U-Mamba@U-Mamba: Enhancing Long-range Dependency for Biomedical Image Segmentation

Tracking Mamba-series papers

The GitHub repository below collects papers from different fields that build on the Mamba architecture.

Mamba_State_Space_Model_Paper_List: https://github.com/Event-AHU/Mamba_State_Space_Model_Paper_List


Chapter 4: Mamba Theory and Analysis

To be continued...


Chapter 5: Summary and Outlook

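  1. April 29, 2024, 16:57:45: Today I finished the environment setup, instantiated the plug-and-play modules, and shared the related papers. Once I have studied Mamba thoroughly, I will write up its theory soon, to help readers quickly grasp the key ideas of the original Mamba paper.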




http://www.chinasem.cn/article/948143
