有效涨点,增强型 YOLOV8 与多尺度注意力特征融合,附代码,详细步骤

本文主要是介绍有效涨点,增强型 YOLOV8 与多尺度注意力特征融合,附代码,详细步骤,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

摘要

结构图 

 原理

代码实现

添加ymal文件 

实验结果

可接论文指导----------> v jiabei-545

完整代码(失效+ -----------👆 )

执行程序流程


摘要

在本实验中,我们通过将双层路由注意(BRA)、广义特征金字塔网络(GFPN)和第四检测头结合到YOLOv8中,开发了一种新颖的BGF-YOLO架构。 BGF-YOLO 包含注意力机制,可以更多地关注重要特征,特征金字塔网络可以通过将高级语义特征与空间细节合并来丰富特征表示。此外,我们研究了不同的注意力机制和特征融合、检测头架构对检测准确性的影响。实验结果表明,与YOLOv8x相比,BGF-YOLO的mAP50增加了4.7%。

我们提出了一种称为 BGF-YOLO 的新模型,它通过结合双层路由注意(BRA)、广义 FPN(GFPN)和第四检测头来增强 YOLOv8 的检测性能。这项工作的贡献总结如下:

1、我们用基于GFPN的结构化特征融合网络重建了YOLOv8的原始颈部部分,以促进不同级别的有效特征融合。

2、我们利用 BRA 进行动态和稀疏注意力机制,以关注更显着的特征并减少特征冗余。

3、我们添加了第四个检测头来丰富锚框的尺度并优化检测的回归损失。

结构图 
BGF-YOLO的架构基于YOLOv8,并结合了新的彩色模块Bi-level Routing Attention(BRA)和Cross Stage Partial DenseNet(CSP)。 Conv、C2f(快捷方式)、SPPF、Concat、Upsample 和 Detect 是原始 YOLOv8 架构中的现有模块。
 原理

广义FPN(GFPN)采用密集链接和皇后融合的结构来产生更好融合的特征,并使用concat操作而不是求和来执行特征融合以减少信息丢失。 AFPN在渐近过程中使用自适应空间融合,首先融合两个低级特征,然后融合更高级别的特征,最后融合顶级特征,以增强关键级别的重要性并减轻来自不同对象的矛盾信息的影响。

注意力机制最初是为了权衡特定特征相对于其他特征的重要性而提出的。在计算机视觉背景下,有五种注意力机制在提高目标检测性能方面具有巨大潜力:挤压和激励(SE)、CBAM、高效通道注意力(ECA)、CA、感受野注意力(RFA)和BRA。它们之间的区别在于SE和ECA属于通道注意力,RFA和BRA处理空间注意力,CBAM和CA促进通道和空间注意力。 SE 是通过显式建模卷积特征通道之间的相互依赖性来自适应地重新校准通道特征响应。 ECA仅捕获本地通道相互依赖性,而不依赖全局统计来减少计算需求。 RFA的优点是提供有效的注意力权重来实现卷积核参数共享。 BRA 是一种动态的、查询感知的稀疏注意力机制,它以内容感知的方式为每个查询启用最相关的键/值标记的一小部分。我们通过采用 BRA 注意模块改进了所提出的基于 GFPN 的特征融合结构,以实现有效的多级特征融合,同时避免跨特征图的冗余信息。动态稀疏注意力可以通过在整合不同尺度的特征图时应用每个通道的权重分布和空间位置来减少冗余特征信息并提高模型的检测精度。我们在特征融合过程中将BRA模块放置在Conv或Upsample模块后面,使模型在特征提取后只关注特定区域。为了进一步避免信息丢失,CSP 模块中的跳过连接使得底层特征图的知识能够在后续层中重用。 BRA 的目标是在更广泛的区域层面消除大多数不相关的键值对输入,只留下少数相关区域。以特征图作为输入,BRA 首先将其分割成各个区域,并通过线性变换导出查询、键和值。将查询和键的区域级关系输入到邻接矩阵中以构造有向图并查明特定键值对的关联。这实质上确定了每个指定区域应涉及哪些领域。最后,利用区域到区域路由索引矩阵在各个令牌之间执行多头自注意力。通过多头自注意力的双层路由优化,更多地关注特征图的ROI部分,从而提高模型检测的能力。

代码实现
class CSPStage(nn.Module):def __init__(self,ch_in,ch_out,n,block_fn='BasicBlock_3x3_Reverse',ch_hidden_ratio=1.0,act='silu',spp=False):super(CSPStage, self).__init__()split_ratio = 2ch_first = int(ch_out // split_ratio)ch_mid = int(ch_out - ch_first)self.conv1 = ConvBNAct(ch_in, ch_first, 1, act=act)self.conv2 = ConvBNAct(ch_in, ch_mid, 1, act=act)self.convs = nn.Sequential()next_ch_in = ch_midfor i in range(n):if block_fn == 'BasicBlock_3x3_Reverse':self.convs.add_module(str(i),BasicBlock_3x3_Reverse(next_ch_in,ch_hidden_ratio,ch_mid,act=act,shortcut=True))else:raise NotImplementedErrorif i == (n - 1) // 2 and spp:self.convs.add_module('spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))next_ch_in = ch_midself.conv3 = ConvBNAct(ch_mid * n + ch_first, ch_out, 1, act=act)def forward(self, x):y1 = self.conv1(x)y2 = self.conv2(x)mid_out = [y1]for conv in self.convs:y2 = conv(y2)mid_out.append(y2)y = torch.cat(mid_out, axis=1)y = self.conv3(y)return y
class BiLevelRoutingAttention(nn.Module):"""n_win: number of windows in one side (so the actual number of windows is n_win*n_win)kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.topk: topk for window filteringparam_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attentionparam_routing: extra linear for routingdiff_routing: wether to set routing differentiablesoft_routing: wether to multiply soft routing weights"""def __init__(self, dim, n_win=7, num_heads=8, qk_dim=None, qk_scale=None,kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,auto_pad=True):super().__init__()# local attention settingself.dim = dimself.n_win = n_win  # Wh, Wwself.num_heads = num_headsself.qk_dim = qk_dim or dimassert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'self.scale = qk_scale or self.qk_dim ** -0.5################side_dwconv (i.e. LCE in ShuntedTransformer)###########self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \lambda x: torch.zeros_like(x)################ global routing setting #################self.topk = topkself.param_routing = param_routingself.diff_routing = diff_routingself.soft_routing = soft_routing# routerassert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=Falseself.router = TopkRouting(qk_dim=self.qk_dim,qk_scale=self.scale,topk=self.topk,diff_routing=self.diff_routing,param_routing=self.param_routing)if self.soft_routing: # soft routing, always diffrentiable (if no detach)mul_weight = 'soft'elif self.diff_routing: # hard differentiable routingmul_weight = 'hard'else:  # hard non-differentiable routingmul_weight = 'none'self.kv_gather = KVGather(mul_weight=mul_weight)# qkv mapping (shared by both global routing and local attention)self.param_attention = param_attentionif self.param_attention == 'qkvo':self.qkv = QKVLinear(self.dim, self.qk_dim)self.wo = nn.Linear(dim, dim)elif self.param_attention == 'qkv':self.qkv = QKVLinear(self.dim, self.qk_dim)self.wo = nn.Identity()else:raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')self.kv_downsample_mode = kv_downsample_modeself.kv_per_win = kv_per_winself.kv_downsample_ratio = kv_downsample_ratioself.kv_downsample_kenel = kv_downsample_kernelif self.kv_downsample_mode == 'ada_avgpool':assert self.kv_per_win is not Noneself.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)elif self.kv_downsample_mode == 'ada_maxpool':assert self.kv_per_win is not Noneself.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)elif self.kv_downsample_mode == 'maxpool':assert self.kv_downsample_ratio is not Noneself.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()elif self.kv_downsample_mode == 'avgpool':assert self.kv_downsample_ratio is not Noneself.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()elif self.kv_downsample_mode == 'identity': # no kv downsamplingself.kv_down = nn.Identity()elif self.kv_downsample_mode == 'fracpool':# assert self.kv_downsample_ratio is not None# assert self.kv_downsample_kenel is not None# TODO: fracpool# 1. kernel size should be input size dependent# 2. there is a random factor, need to avoid independent sampling for k and vraise NotImplementedError('fracpool policy is not implemented yet!')elif kv_downsample_mode == 'conv':# TODO: need to consider the case where k != v so that need two downsample modulesraise NotImplementedError('conv policy is not implemented yet!')else:raise ValueError(f'kv_down_sample_mode {self.kv_downsaple_mode} is not surpported!')# softmax for local attentionself.attn_act = nn.Softmax(dim=-1)self.auto_pad=auto_paddef forward(self, x, ret_attn_mask=False):"""x: NHWC tensorReturn:NHWC tensor"""x = rearrange(x, "n c h w -> n h w c")# NOTE: use padding for semantic segmentation###################################################if self.auto_pad:N, H_in, W_in, C = x.size()pad_l = pad_t = 0pad_r = (self.n_win - W_in % self.n_win) % self.n_winpad_b = (self.n_win - H_in % self.n_win) % self.n_winx = F.pad(x, (0, 0, # dim=-1pad_l, pad_r, # dim=-2pad_t, pad_b)) # dim=-3_, H, W, _ = x.size() # padded sizeelse:N, H, W, C = x.size()assert H%self.n_win == 0 and W%self.n_win == 0 ##################################################### patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size#print(x.shape,self.n_win)x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)# print(x.shape,self.n_win)#################qkv projection#################### q: (n, p^2, w, w, c_qk)# kv: (n, p^2, w, w, c_qk+c_v)# NOTE: separte kv if there were memory leak issue caused by gatherq, kv = self.qkv(x)# pixel-wise qkv# q_pix: (n, p^2, w^2, c_qk)# kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3]) # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)##################side_dwconv(lepe)################### NOTE: call contiguous to avoid gradient warning when using ddplepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)############ gather q dependent k/v #################r_weight, r_idx = self.router(q_win, k_win) # both are (n, p^2, topk) tensorskv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix) #(n, p^2, topk, h_kv*w_kv, c_qk+c_v)k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)# kv_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)# v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)######### do attention as normal ####################k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m)# param-free multihead attentionattn_weight = (q_pix * self.scale) @ k_pix_sel # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)attn_weight = self.attn_act(attn_weight)out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,h=H//self.n_win, w=W//self.n_win)out = out + lepe# output linearout = self.wo(out)# NOTE: use padding for semantic segmentation# crop padded regionif self.auto_pad and (pad_r > 0 or pad_b > 0):out = out[:, :H_in, :W_in, :].contiguous()if ret_attn_mask:return out, r_weight, r_idx, attn_weightelse:return rearrange(out, "n h w c -> n c h w")
添加ymal文件 
# BGF-YOLO based on Ultralytics YOLOv8x 8.0.109 object detection model with same license, AGPL-3.0 license
# BGF object detection model with P3-P6 outputs.# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, Conv, [512, 1, 1]]- [ 6, 1, Conv, [ 512, 3, 2 ] ]- [-1, 1, BiLevelRoutingAttention, [8 ]]- [ [ -1, 10 ], 1, Concat, [ 1 ] ]- [ -1, 3, CSPStage, [ 512 ] ] # 14      20*20- [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]- [4, 1, Conv, [256, 3, 2]]- [-1, 1, BiLevelRoutingAttention, [8 ]]- [ [ 15, -1, 6 ], 1, Concat, [ 1 ] ]- [-1, 3, CSPStage, [512]] # 19     40*40- [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]- [ -1, 1, BiLevelRoutingAttention, [ 8 ] ]- [ [ -1, 4 ], 1, Concat, [ 1 ] ]- [ -1, 3, CSPStage, [ 256 ] ]# 23   80*80- [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]- [ -1, 1, BiLevelRoutingAttention, [ 8 ] ]- [ [ -1, 2 ], 1, Concat, [ 1 ] ]- [ -1, 3, CSPStage, [ 128 ] ] # 27   160*160- [ -1, 1, Conv, [ 128, 3, 2 ] ]- [-1, 1, BiLevelRoutingAttention, [8 ]]- [ [ -1, 23  ], 1, Concat, [ 1 ] ]  # cat head P4- [ -1, 3, CSPStage, [ 256 ] ]  # 31     80*80- [ -1, 1, Conv, [ 256, 3, 2 ] ]- [ -1, 1, BiLevelRoutingAttention, [ 8 ] ]- [ [ -1, 19 ], 1, Concat, [ 1 ] ]  # cat head P4- [ -1, 3, CSPStage, [ 512 ] ]  # 35    40*40- [ 19, 1, Conv, [ 256, 3, 2 ] ] #   20*20- [ 35, 1, Conv, [ 256, 3, 2 ] ] #   20*20- [ [14, 36, -1 ], 1, Concat, [ 1 ] ]- [ -1, 3, CSPStage, [ 1024 ] ]  # 39    20*20- [ [ 27, 31, 35, 39 ], 1, Detect, [ nc ] ]  # Detect(P3, P4, P5, P6)
实验结果

通过优化 GFPN 特征融合结构、BRA 注意力机制以及在 BGF-YOLO 模型中添加检测头,YOLOv8 的目标检测能力得到了显着增强。这些修改能够实现不同级别和更丰富尺度的加权特征融合,并产生具有动态聚焦机制的高质量锚框。此外,BGF-YOLO 中提出的模块优于其他替代技术,对不同特征融合结构、注意力机制和回归损失。我们提出的 BGF-YOLO 成为检测数据集 上当前最先进的模型。 

可接论文指导----------> v jiabei-545
完整代码(失效+ -----------👆 )

链接: https://pan.baidu.com/s/1K5I13pJdsKvGyMub5c1HKw?pwd=zk88 提取码: zk88 

执行程序流程

pip install -r requirements.txt  # install

python yolo/bgf/detect/train.py 

现在就能顺利的执行了 

ok,快去试试你的实验有没有涨点!

这篇关于有效涨点,增强型 YOLOV8 与多尺度注意力特征融合,附代码,详细步骤的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/749065

相关文章

将Mybatis升级为Mybatis-Plus的详细过程

《将Mybatis升级为Mybatis-Plus的详细过程》本文详细介绍了在若依管理系统(v3.8.8)中将MyBatis升级为MyBatis-Plus的过程,旨在提升开发效率,通过本文,开发者可实现... 目录说明流程增加依赖修改配置文件注释掉MyBATisConfig里面的Bean代码生成使用IDEA生

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

springboot循环依赖问题案例代码及解决办法

《springboot循环依赖问题案例代码及解决办法》在SpringBoot中,如果两个或多个Bean之间存在循环依赖(即BeanA依赖BeanB,而BeanB又依赖BeanA),会导致Spring的... 目录1. 什么是循环依赖?2. 循环依赖的场景案例3. 解决循环依赖的常见方法方法 1:使用 @La

使用C#代码在PDF文档中添加、删除和替换图片

《使用C#代码在PDF文档中添加、删除和替换图片》在当今数字化文档处理场景中,动态操作PDF文档中的图像已成为企业级应用开发的核心需求之一,本文将介绍如何在.NET平台使用C#代码在PDF文档中添加、... 目录引言用C#添加图片到PDF文档用C#删除PDF文档中的图片用C#替换PDF文档中的图片引言在当

Linux系统中卸载与安装JDK的详细教程

《Linux系统中卸载与安装JDK的详细教程》本文详细介绍了如何在Linux系统中通过Xshell和Xftp工具连接与传输文件,然后进行JDK的安装与卸载,安装步骤包括连接Linux、传输JDK安装包... 目录1、卸载1.1 linux删除自带的JDK1.2 Linux上卸载自己安装的JDK2、安装2.1

C#使用SQLite进行大数据量高效处理的代码示例

《C#使用SQLite进行大数据量高效处理的代码示例》在软件开发中,高效处理大数据量是一个常见且具有挑战性的任务,SQLite因其零配置、嵌入式、跨平台的特性,成为许多开发者的首选数据库,本文将深入探... 目录前言准备工作数据实体核心技术批量插入:从乌龟到猎豹的蜕变分页查询:加载百万数据异步处理:拒绝界面

用js控制视频播放进度基本示例代码

《用js控制视频播放进度基本示例代码》写前端的时候,很多的时候是需要支持要网页视频播放的功能,下面这篇文章主要给大家介绍了关于用js控制视频播放进度的相关资料,文中通过代码介绍的非常详细,需要的朋友可... 目录前言html部分:JavaScript部分:注意:总结前言在javascript中控制视频播放

Java使用Curator进行ZooKeeper操作的详细教程

《Java使用Curator进行ZooKeeper操作的详细教程》ApacheCurator是一个基于ZooKeeper的Java客户端库,它极大地简化了使用ZooKeeper的开发工作,在分布式系统... 目录1、简述2、核心功能2.1 CuratorFramework2.2 Recipes3、示例实践3

Spring Boot 3.4.3 基于 Spring WebFlux 实现 SSE 功能(代码示例)

《SpringBoot3.4.3基于SpringWebFlux实现SSE功能(代码示例)》SpringBoot3.4.3结合SpringWebFlux实现SSE功能,为实时数据推送提供... 目录1. SSE 简介1.1 什么是 SSE?1.2 SSE 的优点1.3 适用场景2. Spring WebFlu

java之Objects.nonNull用法代码解读

《java之Objects.nonNull用法代码解读》:本文主要介绍java之Objects.nonNull用法代码,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录Java之Objects.nonwww.chinasem.cnNull用法代码Objects.nonN