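This article presents a new attention module that I designed by combining Deformable Large Kernel Attention (D-LKA) with DCNv3 deformable convolution (a second-round innovation built on top of D-LKA). D-LKA is an attention mechanism built from large convolution kernels and deformable convolution: it uses large kernels to obtain a receptive field similar to that of self-attention while avoiding the high computational cost of conventional self-attention. On top of that, we use DCNv3 to replace the deformable convolutions inside it. The content of this article was compiled exclusively by me.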
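2.1 Basic principle of Deformable-LKA
2.2 Large kernels
2.3 Deformable convolution DCNv3
2.4 2D and 3D adaptability
4.1 Modification 1
4.2 Modification 2
4.3 Modification 3
4.4 Modification 4
5.1 The Deformable-LKA yaml file
5.2 Screenshot of the Deformable-LKA training process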
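2.1 Basic principle of Deformable-LKA
Deformable Large Kernel Attention (D-LKA) is an attention mechanism that combines large convolution kernels with deformable convolution. It uses large kernels to obtain a receptive field similar to that of self-attention while avoiding the high computational cost of conventional self-attention. In addition, D-LKA uses deformable convolution to adjust its sampling grid flexibly, so the model can adapt better to different data patterns. Its main points are:
1. Large kernels: D-LKA uses large convolution kernels to capture broad contextual information in the image, mimicking the receptive field of self-attention.
2. Deformable convolution: with deformable convolution, the model's sampling grid can deform according to the image features, adapting to different data patterns.
3. 2D and 3D adaptability: D-LKA comes in 2D and 3D versions, so it can be applied both to 2D images and to 3D volumetric data.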
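2.2 Large kernels
A large kernel is a mechanism for capturing wide contextual information in an image. It mimics the receptive field of self-attention, but with fewer parameters and less computation than a dense kernel of the same size. A large kernel can be built efficiently by stacking a depth-wise convolution, a depth-wise dilated convolution and a 1×1 (point-wise) convolution. This lets the network learn features over a large receptive field while keeping the parameter count, and therefore the computational cost, low. In Deformable LKA, the large kernel is combined with deformable convolution, which further improves the model's adaptability to complex image patterns.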
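As a rough illustration of why this decomposition is cheap, the sketch below counts receptive field and weights in plain Python. It assumes the (kernel, dilation) pairs that `deformable_LKA` passes to its two `DeformConv` layers in the code section (5×5, then 7×7 with dilation 3, then the 1×1 `conv1`), treats them as ordinary depth-wise layers on a regular grid (learned offsets are ignored), and uses an arbitrary channel count of 64. The helper names are illustrative only.

```python
def effective_kernel(k, dilation=1):
    """Input span covered by one k-tap kernel with the given dilation."""
    return dilation * (k - 1) + 1


def stacked_receptive_field(layers):
    """Receptive field of stride-1 convolutions applied one after another.

    `layers` is a list of (kernel_size, dilation); each layer widens the
    field by (effective_kernel - 1).
    """
    rf = 1
    for k, d in layers:
        rf += effective_kernel(k, d) - 1
    return rf


def decomposed_params(channels, dw_k=5, dwd_k=7):
    """Weights of depth-wise k x k + depth-wise dilated k x k + point-wise 1x1 (biases ignored)."""
    return channels * dw_k ** 2 + channels * dwd_k ** 2 + channels * channels


def dense_params(channels, k):
    """Weights of one ordinary (non-grouped) k x k convolution, biases ignored."""
    return channels * channels * k * k


C = 64  # arbitrary example channel count
rf = stacked_receptive_field([(5, 1), (7, 3), (1, 1)])
print(rf)                    # 23
print(decomposed_params(C))  # 8832
print(dense_params(C, rf))   # 2166784
```

Under these assumptions the stacked layers see a 23×23 window with about 0.4% of the weights of a dense 23×23 convolution.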
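The figure above shows the architecture of the Deformable Large Kernel Attention (D-LKA) module. It consists of several convolutional layers:
1. A standard 2D convolution (Conv2D).
2. A deformable depth-wise convolution with offsets (Deform-DW Conv2D), which lets the network adapt its receptive field to the input features.
3. An offsets field, produced by a standard convolution layer, which tells the deformable convolution how to shift its sampling locations.
4. A GELU activation, which adds non-linearity.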
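The role of the offsets field can be illustrated with a minimal pure-Python sketch of the standard deformable-convolution sampling rule p = p0 + pk + Δpk: every tap pk of a regular (optionally dilated) k×k grid around the output position p0 is shifted by its own learned offset Δpk. The function names are hypothetical, and a real layer predicts the offsets with a convolution (2·k·k values per position) rather than receiving them as a list.

```python
def kernel_taps(k, dilation=1):
    """Regular k x k grid of (dx, dy) taps centred on 0, spaced by `dilation`."""
    half = (k - 1) // 2
    return [((j - half) * dilation, (i - half) * dilation)
            for i in range(k) for j in range(k)]


def sampling_positions(p0, offsets, k=3, dilation=1):
    """Deformable sampling positions p0 + pk + delta_pk for one output location.

    p0      -- (x, y) position of the output pixel in the input
    offsets -- k*k (dx, dy) pairs, one per kernel tap (from the offsets field)
    """
    taps = kernel_taps(k, dilation)
    if len(offsets) != len(taps):
        raise ValueError("expected one (dx, dy) offset per kernel tap")
    return [(p0[0] + tx + ox, p0[1] + ty + oy)
            for (tx, ty), (ox, oy) in zip(taps, offsets)]


# With all-zero offsets the layer samples the ordinary 3x3 grid.
zero = [(0.0, 0.0)] * 9
print(sampling_positions((10, 10), zero)[:3])  # [(9.0, 9.0), (10.0, 9.0), (11.0, 9.0)]

# A learned offset moves an individual tap to a fractional location.
shifted = zero.copy()
shifted[4] = (0.5, -1.25)  # tap 4 is the centre tap
print(sampling_positions((10, 10), shifted)[4])  # (10.5, 8.75)
```

Because the shifted positions are generally fractional, the input has to be interpolated at those points, which is why deformable convolution relies on bilinear sampling.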
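2.3 Deformable convolution DCNv3
Let us start with the broader concept. DCN stands for Deformable Convolutional Networks. It is a convolutional module for object detection and image segmentation that introduces deformable convolution to improve the model's ability to handle object deformation. DCNv3 is the version introduced with InternImage: it splits the channels into groups, each with its own sampling offsets and modulation weights, and normalizes the modulation weights with softmax over the sampling points (see `mask = F.softmax(mask, -1)` in `DCNv3_pytorch` in the code section).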
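A DCNv3-style output for one location can be sketched in plain Python: bilinearly sample the input at each (possibly fractional) sampling position, treating out-of-range pixels as zero, then take a weighted sum whose weights are the modulation logits normalized with softmax over the sampling points. This mirrors the roles of `F.grid_sample(..., padding_mode='zeros')` and `F.softmax(mask, -1)` in `dcnv3_core_pytorch`/`DCNv3_pytorch`, but it is a simplified sketch that assumes one channel and one group and leaves out the input/output projections; `bilinear`, `softmax` and `dcnv3_point` are illustrative names, not library APIs.

```python
import math


def bilinear(img, x, y):
    """Bilinearly sample a 2D list `img` at fractional (x, y); out-of-range pixels count as 0."""
    h, w = len(img), len(img[0])

    def px(r, c):
        return img[r][c] if 0 <= r < h and 0 <= c < w else 0.0

    x0, y0 = math.floor(x), math.floor(y)
    ax, ay = x - x0, y - y0
    return ((1 - ax) * (1 - ay) * px(y0, x0) + ax * (1 - ay) * px(y0, x0 + 1)
            + (1 - ax) * ay * px(y0 + 1, x0) + ax * ay * px(y0 + 1, x0 + 1))


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]


def dcnv3_point(img, positions, mask_logits):
    """sum_k softmax(mask_logits)_k * img(p_k) for one location, one group, one channel."""
    weights = softmax(mask_logits)
    return sum(wt * bilinear(img, x, y) for wt, (x, y) in zip(weights, positions))


img = [[0.0, 1.0, 2.0],
       [3.0, 4.0, 5.0],
       [6.0, 7.0, 8.0]]  # value = 3 * row + col
grid = [(x, y) for y in range(3) for x in range(3)]  # 3x3 taps around (1, 1), zero offsets

print(bilinear(img, 0.5, 0.5))                               # 2.0
print(round(dcnv3_point(img, grid, [0.0] * 9), 6))           # 4.0
print(round(dcnv3_point(img, grid, [0.0] * 8 + [50.0]), 6))  # 8.0
```

With uniform logits the result is the plain average of the nine taps; a peaked logit vector lets the layer concentrate on a single sampling point.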
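2.4 2D and 3D adaptability
2D and 3D adaptability refers to the ability of Deformable Large Kernel Attention (D-LKA) to handle data of different dimensionality. 2D D-LKA is designed for two-dimensional images and suits common medical imaging modalities such as X-ray images or single MRI slices. 3D D-LKA extends the technique to three-dimensional datasets, making full use of the spatial context in volumetric image data. The 3D version is particularly good at cross-depth understanding, i.e. analyzing and recognizing image features across multiple slices, which is useful for volumetric reconstruction and more complex medical imaging tasks.
The figure above shows the network architectures of the 3D and 2D Deformable Large Kernel Attention (D-LKA) models: the 3D D-LKA model is on the left and the 2D D-LKA model is on the right.
1. 3D D-LKA model (left): contains multiple 3D D-LKA blocks placed between the downsampling and upsampling stages, used for deep feature learning and for recovering resolution.
2. 2D D-LKA model (right): uses MaxViT blocks as the encoder and 2D D-LKA blocks at different resolution levels, learning features through patch expanding and the D-LKA attention mechanism.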
import warnings

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_

__all__ = ['C3_DCNv3_DLKA', 'deformable_LKA_Attention']


def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0,
                          stride_h=1, stride_w=1):
    _, H_, W_, _ = spatial_shapes
    H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
    W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1

    ref_y, ref_x = torch.meshgrid(
        torch.linspace(
            # pad_h + 0.5,
            # H_ - pad_h - 0.5,
            (dilation_h * (kernel_h - 1)) // 2 + 0.5,
            (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h,
            H_out,
            dtype=torch.float32,
            device=device),
        torch.linspace(
            # pad_w + 0.5,
            # W_ - pad_w - 0.5,
            (dilation_w * (kernel_w - 1)) // 2 + 0.5,
            (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w,
            W_out,
            dtype=torch.float32,
            device=device))
    ref_y = ref_y.reshape(-1)[None] / H_
    ref_x = ref_x.reshape(-1)[None] / W_

    ref = torch.stack((ref_x, ref_y), -1).reshape(1, H_out, W_out, 1, 2)

    return ref


def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device):
    _, H_, W_, _ = spatial_shapes
    points_list = []
    x, y = torch.meshgrid(
        torch.linspace(
            -((dilation_w * (kernel_w - 1)) // 2),
            -((dilation_w * (kernel_w - 1)) // 2) + (kernel_w - 1) * dilation_w,
            kernel_w,
            dtype=torch.float32,
            device=device),
        torch.linspace(
            -((dilation_h * (kernel_h - 1)) // 2),
            -((dilation_h * (kernel_h - 1)) // 2) + (kernel_h - 1) * dilation_h,
            kernel_h,
            dtype=torch.float32,
            device=device))

    points_list.extend([x / W_, y / H_])
    grid = torch.stack(points_list, -1).reshape(-1, 1, 2). \
        repeat(1, group, 1).permute(1, 0, 2)
    grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)

    return grid


def dcnv3_core_pytorch(
        input, offset, mask, kernel_h,
        kernel_w, stride_h, stride_w, pad_h,
        pad_w, dilation_h, dilation_w, group,
        group_channels, offset_scale):
    # for debug and test only,
    # need to use cuda version instead
    input = F.pad(
        input,
        [0, 0, pad_h, pad_h, pad_w, pad_w])
    N_, H_in, W_in, _ = input.shape
    _, H_out, W_out, _ = offset.shape

    ref = _get_reference_points(
        input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w)
    grid = _generate_dilation_grids(
        input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device)
    spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2). \
        repeat(1, 1, 1, group * kernel_h * kernel_w).to(input.device)

    sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \
        offset * offset_scale / spatial_norm

    P_ = kernel_h * kernel_w
    sampling_grids = 2 * sampling_locations - 1
    # N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in
    input_ = input.view(N_, H_in * W_in, group * group_channels).transpose(1, 2). \
        reshape(N_ * group, group_channels, H_in, W_in)
    # N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
    sampling_grid_ = sampling_grids.view(N_, H_out * W_out, group, P_, 2).transpose(1, 2). \
        flatten(0, 1)
    # N_*group, group_channels, H_out*W_out, P_
    sampling_input_ = F.grid_sample(
        input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)

    # (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_)
    mask = mask.view(N_, H_out * W_out, group, P_).transpose(1, 2). \
        reshape(N_ * group, 1, H_out * W_out, P_)
    output = (sampling_input_ * mask).sum(-1).view(N_, group * group_channels, H_out * W_out)

    return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()


class to_channels_first(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 3, 1, 2)


class to_channels_last(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 2, 3, 1)


def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    layers = []
    if norm_layer == 'BN':
        if in_format == 'channels_last':
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    elif norm_layer == 'LN':
        if in_format == 'channels_first':
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    return nn.Sequential(*layers)


def build_act_layer(act_layer):
    if act_layer == 'ReLU':
        return nn.ReLU(inplace=True)
    elif act_layer == 'SiLU':
        return nn.SiLU(inplace=True)
    elif act_layer == 'GELU':
        return nn.GELU()

    raise NotImplementedError(f'build_act_layer does not support {act_layer}')


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError(
            "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))

    return (n & (n - 1) == 0) and n != 0


class CenterFeatureScaleModule(nn.Module):
    def forward(self,
                query,
                center_feature_scale_proj_weight,
                center_feature_scale_proj_bias):
        center_feature_scale = F.linear(query,
                                        weight=center_feature_scale_proj_weight,
                                        bias=center_feature_scale_proj_bias).sigmoid()
        return center_feature_scale


class DCNv3_pytorch(nn.Module):
    def __init__(
            self,
            channels=64,
            kernel_size=3,
            dw_kernel_size=None,
            stride=1,
            pad=1,
            dilation=1,
            group=4,
            offset_scale=1.0,
            act_layer='GELU',
            norm_layer='LN',
            center_feature_scale=False):
        """
        DCNv3 Module
        :param channels
        :param kernel_size
        :param stride
        :param pad
        :param dilation
        :param group
        :param offset_scale
        :param act_layer
        :param norm_layer
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
        # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
                "which is more efficient in our CUDA implementation.")

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.offset_scale = offset_scale
        self.center_feature_scale = center_feature_scale

        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels,
                norm_layer,
                'channels_first',
                'channels_last'),
            build_act_layer(act_layer))
        self.offset = nn.Linear(
            channels,
            group * kernel_size * kernel_size * 2)
        self.mask = nn.Linear(
            channels,
            group * kernel_size * kernel_size)
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input (N, C, H, W)
        :return output (N, C, H, W)
        """
        input = input.permute(0, 2, 3, 1)  # N, C, H, W -> N, H, W, C
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        x_proj = x

        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)  # channels_first in, channels_last out
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1).reshape(N, H, W, -1)  # softmax over the sampling points of each group

        x = dcnv3_core_pytorch(
            x, offset, mask,
            self.kernel_size, self.kernel_size,
            self.stride, self.stride,
            self.pad, self.pad,
            self.dilation, self.dilation,
            self.group, self.group_channels,
            self.offset_scale)
        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
            # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x).permute(0, 3, 1, 2)  # back to N, C, H, W

        return x


class DeformConv(nn.Module):

    def __init__(self, in_channels, groups, kernel_size=(3, 3), padding=1, stride=1, dilation=1, bias=True):
        super(DeformConv, self).__init__()
        # Note: only in_channels is forwarded. groups, kernel_size, padding, stride, dilation and bias
        # are accepted for interface compatibility but ignored, so every DeformConv is a
        # DCNv3_pytorch with its defaults (3x3 kernel, pad=1, dilation=1, group=4).
        self.deform_conv = DCNv3_pytorch(in_channels)

    def forward(self, x):
        out = self.deform_conv(x)
        return out


class deformable_LKA(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = DeformConv(dim, kernel_size=(5, 5), padding=2, groups=dim)
        self.conv_spatial = DeformConv(dim, kernel_size=(7, 7), stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        return u * attn  # the attention map gates the input element-wise


class deformable_LKA_Attention(nn.Module):
    def __init__(self, d_model):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = deformable_LKA(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut
        return x


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization (used after BN is fused into the conv)."""
        return self.act(self.conv(x))


class Bottleneck(nn.Module):
    # Standard bottleneck followed by deformable_LKA_Attention
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.Dattention = deformable_LKA_Attention(c2)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.Dattention(self.cv2(self.cv1(x))) if self.add else self.Dattention(self.cv2(self.cv1(x)))


class C3_DCNv3_DLKA(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))


if __name__ == "__main__":
    # Generating Sample image
    image_size = (1, 64, 224, 224)
    image = torch.rand(*image_size)

    # Model
    model = C3_DCNv3_DLKA(64, 64)

    out = model(image)
    print(out.size())
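The spatial size produced by `dcnv3_core_pytorch` follows the `H_out`/`W_out` formula in `_get_reference_points`, applied to the input after padding. The plain-Python check below (the helper name is hypothetical) shows that the default 3×3, pad=1 setting that `DCNv3_pytorch` actually uses keeps a 224×224 map at 224×224, and that the padding values written in `deformable_LKA` (5×5 with padding 2, 7×7 with padding 9 and dilation 3) would also preserve the size if they were forwarded to `DCNv3_pytorch`.

```python
def dcn_output_size(size, kernel, pad, dilation=1, stride=1):
    """Output length along one axis: the H_out/W_out formula from
    _get_reference_points, applied to an input padded by `pad` on both sides."""
    padded = size + 2 * pad
    return (padded - (dilation * (kernel - 1) + 1)) // stride + 1


# (kernel, pad, dilation): DCNv3_pytorch defaults, then the values written in deformable_LKA
for k, p, d in [(3, 1, 1), (5, 2, 1), (7, 9, 3)]:
    print(k, p, d, dcn_output_size(224, k, p, d))
# 3 1 1 224
# 5 2 1 224
# 7 9 3 224
```

With `stride=2` the same formula gives `dcn_output_size(224, 3, 1, stride=2) == 112`, i.e. the usual halving of resolution.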
4.1 Modification 1
4.2 Modification 2
4.3 Modification 3
4.4 Modification 4
5.1 The Deformable-LKA yaml file
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23] # P3/8
  - [30,61, 62,45, 59,119] # P4/16
  - [116,90, 156,198, 373,326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, "nearest"]],
   [[-1, 6], 1, Concat, [1]], # cat backbone P4
   [-1, 3, C3_DCNv3_DLKA, [512, False]], # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, "nearest"]],
   [[-1, 4], 1, Concat, [1]], # cat backbone P3
   [-1, 3, C3_DCNv3_DLKA, [256, False]], # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]], # cat head P4
   [-1, 3, C3_DCNv3_DLKA, [512, False]], # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]], # cat head P5
   [-1, 3, C3_DCNv3_DLKA, [1024, False]], # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
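Because `DCNv3_pytorch` raises an error when `channels` is not divisible by `group` and warns when the channels per group are not a power of two, it is worth checking the channel counts this yaml produces. The sketch below assumes that YOLOv5's `parse_model` scales each layer's output channels with `make_divisible(c2 * width_multiple, 8)` and that `DeformConv` keeps the default `group=4`; the layer indices come from the comments in the head section, and `dcnv3_channels` is a hypothetical helper.

```python
import math


def make_divisible(x, divisor=8):
    """Round x up to a multiple of divisor (assumed YOLOv5 width-scaling rule)."""
    return math.ceil(x / divisor) * divisor


def is_power_of_2(n):
    return n > 0 and (n & (n - 1)) == 0


def dcnv3_channels(c2_yaml, width_multiple=0.25, e=0.5):
    """Channels that reach DCNv3_pytorch inside one C3_DCNv3_DLKA layer.

    C3_DCNv3_DLKA builds Bottleneck(c_, c_) with c_ = int(c2 * e), and each
    Bottleneck wraps deformable_LKA_Attention(c_) -> DCNv3_pytorch(c_).
    """
    c2 = make_divisible(c2_yaml * width_multiple)
    return int(c2 * e)


GROUP = 4  # DCNv3_pytorch default, which DeformConv does not override
for layer, c2_yaml in [(13, 512), (17, 256), (20, 512), (23, 1024)]:
    ch = dcnv3_channels(c2_yaml)
    per_group = ch // GROUP
    ok = ch % GROUP == 0 and is_power_of_2(per_group)
    print(layer, ch, per_group, ok)
# 13 64 16 True
# 17 32 8 True
# 20 64 16 True
# 23 128 32 True
```

Under these assumptions every C3_DCNv3_DLKA layer in this configuration passes the divisibility check and avoids the power-of-two warning.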
5.2 Screenshot of the Deformable-LKA training process
