【目标检测实验系列】YOLOv5模型改进:融合混合注意力机制CBAM,关注通道和空间特征,助力模型高效涨点!(内含源代码,超详细改进代码流程)

本文主要是介绍【目标检测实验系列】YOLOv5模型改进:融合混合注意力机制CBAM,关注通道和空间特征,助力模型高效涨点!(内含源代码,超详细改进代码流程),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

       自我介绍:本人硕士期间全程放养,目前成果:一篇北大核心CSCD录用,两篇中科院三区已见刊,一篇中科院四区在投。如何找创新点,如何放养过程厚积薄发,如何写中英论文,找期刊等等。本人后续会以自己实战经验详细写出来,还请大家能够点个关注和赞,收藏一下,谢谢大家。

1. 文章主要内容

       本篇博客主要涉及混合(通道角度与空间角度)注意力机制CBAM融合到YOLOv5模型中。(通读本篇博客需要7分钟左右的时间)

2. 详细代码改进流程

2.1 CBAM源代码(大家自己创建CBAM.py文件)

       注意,博主在CBAM源码当中添加了C3与CBAM结合的代码,还有main函数的测试案例,不影响CBAM的单独使用。

import numpy as np
import torch
from torch import nn
from torch.nn import initfrom models.common import Bottleneck, Convclass ChannelAttention(nn.Module):def __init__(self, channel, reduction=16):super().__init__()self.maxpool = nn.AdaptiveMaxPool2d(1)self.avgpool = nn.AdaptiveAvgPool2d(1)self.se = nn.Sequential(nn.Conv2d(channel, channel // reduction, 1, bias=False),nn.ReLU(),nn.Conv2d(channel // reduction, channel, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):max_result = self.maxpool(x)avg_result = self.avgpool(x)max_out = self.se(max_result)avg_out = self.se(avg_result)output = self.sigmoid(max_out + avg_out)return outputclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)self.sigmoid = nn.Sigmoid()def forward(self, x):max_result, _ = torch.max(x, dim=1, keepdim=True)avg_result = torch.mean(x, dim=1, keepdim=True)result = torch.cat([max_result, avg_result], 1)output = self.conv(result)output = self.sigmoid(output)return outputclass CBAMBlock(nn.Module):def __init__(self, channel=512, reduction=16, kernel_size=7):super().__init__()self.ca = ChannelAttention(channel=channel, reduction=reduction)self.sa = SpatialAttention(kernel_size=kernel_size)def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()out = x * self.ca(x)out = out * self.sa(out)return outclass C3CBAM(nn.Module):def __init__(self, c1, c2, n=1, shortcut=True, g=1,e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion #iscyysuper(C3CBAM, self).__init__()c_ = int(c2 * e)  # hidden channelsself.cbam = CBAMBlock(c1)self.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c1, c_, 1, 1)self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)# self.m = nn.Sequential(*[CB2d(c_) for _ in range(n)])self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])def forward(self, x):out = torch.cat((self.m(self.cv1(self.cbam(x))), self.cv2(self.cbam(x))), dim=1)out = self.cv3(out)return outif __name__ == '__main__':input = torch.randn(50, 512, 7, 7)cbam = C3CBAM(512, 512)output = cbam(input)print(output.shape)

       需要注意到: 源代码CBAMBlock类只需要传入一个输入的通道数channel,与YOLOv5的C3结构融合后,则C3CBAM需要传入输入和输出通道数,但大家仔细发现在C3CBAM的这行代码self.cbam =CBAMBlock(c1),实际的CBAM也只是需要传入输入的通道数即可。大家可以通过main函数进行测试。另外,在C3CBAM中,其中cv1和cv2方法里面的参数x都先通过了cbam注意力机制,这里大家可以自定义的设置。

2.2 建立一个yolov5-cbam.yaml文件

       注意到,这里博主直接使用C3CBAM代替Backbone部分的四个C3结构,另外注意nc改为自己数据集的类别数。当然,CBAM结构可以自由的放到网络当中的任何结构,但需要特别注意放了之后层次的更替问题,如有不懂,可以查看我之前写的一篇博客(以及评论区注意点):【目标检测实验系列】通过全局上下文注意力机制Global Context Block(GC)融合到YOLOv5案例,吃透简单即插即用注意力机制代码修改要点,举一反三!(超详细改进代码流程)

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 4  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8  小目标- [30,61, 62,45, 59,119]  # P4/16 中目标- [116,90, 156,198, 373,326]  # P5/32  大目标# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2  output_channel, kernel_size, stride, padding[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3CBAM, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3CBAM, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3CBAM, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3CBAM, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

3.3 将C3CBAM引入到yolo.py文件中

       在下图的位置处,引入相关的类即可。
在这里插入图片描述

3.4 修改train.py启动文件

       修改配置文件为yolov5-cbam.yaml即可,如下图所示:
在这里插入图片描述

4. 总结

       本篇博客主要介绍了CBAM融合到YOLOv5模型,通过关注通道和空间特征,助力模型高效涨点。另外,在修改过程中,要是有任何问题,评论区交流;如果博客对您有帮助,请帮忙点个赞,收藏一下;后续会持续更新本人实验当中觉得有用的点子,如果很感兴趣的话,可以关注一下,谢谢大家啦!

这篇关于【目标检测实验系列】YOLOv5模型改进:融合混合注意力机制CBAM,关注通道和空间特征,助力模型高效涨点!(内含源代码,超详细改进代码流程)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/569361

相关文章

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

Security OAuth2 单点登录流程

单点登录(英语:Single sign-on,缩写为 SSO),又译为单一签入,一种对于许多相互关连,但是又是各自独立的软件系统,提供访问控制的属性。当拥有这项属性时,当用户登录时,就可以获取所有系统的访问权限,不用对每个单一系统都逐一登录。这项功能通常是以轻型目录访问协议(LDAP)来实现,在服务器上会将用户信息存储到LDAP数据库中。相同的,单一注销(single sign-off)就是指

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

高效+灵活,万博智云全球发布AWS无代理跨云容灾方案!

摘要 近日,万博智云推出了基于AWS的无代理跨云容灾解决方案,并与拉丁美洲,中东,亚洲的合作伙伴面向全球开展了联合发布。这一方案以AWS应用环境为基础,将HyperBDR平台的高效、灵活和成本效益优势与无代理功能相结合,为全球企业带来实现了更便捷、经济的数据保护。 一、全球联合发布 9月2日,万博智云CEO Michael Wong在线上平台发布AWS无代理跨云容灾解决方案的阐述视频,介绍了

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖

嵌入式QT开发:构建高效智能的嵌入式系统

摘要: 本文深入探讨了嵌入式 QT 相关的各个方面。从 QT 框架的基础架构和核心概念出发,详细阐述了其在嵌入式环境中的优势与特点。文中分析了嵌入式 QT 的开发环境搭建过程,包括交叉编译工具链的配置等关键步骤。进一步探讨了嵌入式 QT 的界面设计与开发,涵盖了从基本控件的使用到复杂界面布局的构建。同时也深入研究了信号与槽机制在嵌入式系统中的应用,以及嵌入式 QT 与硬件设备的交互,包括输入输出设

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来