YOLOv8改进实战 | 注意力篇 | 引入基于跨空间学习的高效多尺度注意力EMA,小目标涨点明显

本文主要是介绍YOLOv8改进实战 | 注意力篇 | 引入基于跨空间学习的高效多尺度注意力EMA,小目标涨点明显,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


在这里插入图片描述


在这里插入图片描述
YOLOv8专栏导航:点击此处跳转


前言

YOLOv8 是由 YOLOv5 的发布者 Ultralytics 发布的最新版本的 YOLO。它可用于对象检测、分割、分类任务以及大型数据集的学习,并且可以在包括 CPU 和 GPU 在内的各种硬件上执行。

YOLOv8 是一种尖端的、最先进的 (SOTA) 模型,它建立在以前成功的 YOLO 版本的基础上,并引入了新的功能和改进,以进一步提高性能和灵活性。YOLOv8 旨在快速、准确且易于使用,这也使其成为对象检测、图像分割和图像分类任务的绝佳选择。具体创新包括一个新的骨干网络、一个新的 Ancher-Free 检测头和一个新的损失函数,还支持YOLO以往版本,方便不同版本切换和性能对比。


目录

  • 一、EMA介绍
  • 二、代码实现
    • 代码目录
    • 注册模块
    • 配置yaml文件
  • 三、模型测试
  • 四、模型训练
  • 五、总结

一、EMA介绍

在这里插入图片描述

论文链接:Efficient Multi-Scale Attention Module with Cross-Spatial Learning

在这里插入图片描述

论文提出了一种新颖的高效多尺度注意力(EMA)模块。EMA模块旨在保留每个通道的信息,同时减少计算开销。它通过重塑部分通道到批次维度,并将通道雏度分组为多个子特征,使得空间语义特征在每个特征组内均匀分布。此外,EMA模块通过编码全局信息来重新校准每个并行分支中的通道权重,并通过跨维度交互来捕获像素级别的成对关系。

在这里插入图片描述

创新点主要包括:

  1. 高效多尺度注意力(EMA):新型的注意力机制,同时减少计算开销和保留每个通道的关键信息

  2. 通道和批次维度的重组:通过重新组织通道维度和批次维度,提高了模型处理特征的能力。

  3. 跨维度交互:模块利用跨维度的交互来捕捉像素级别的关系

  4. 全局信息编码和通道权重校准:在并行分支中编码全局信息,用于通道权重的重新校准,增强了特征表示的能力。

二、代码实现

代码目录

  • 按下面文件夹结构创建文件(相比于在原有ultralytics/nn/modules文件夹下的相关文件中直接添加便于管理
    - ultralytics- nn- extra_modules- __init__.py- attention.py- modules
    

ultralytics/nn/extra_modules/__init__.py中添加:

from .attention import *

ultralytics/nn/extra_modules/attention.py中添加:

import torch
from torch import nn__all__ = ['EMA']class EMA(nn.Module):def __init__(self, channels, factor=8):super(EMA, self).__init__()self.groups = factorassert channels // self.groups > 0self.softmax = nn.Softmax(-1)self.agp = nn.AdaptiveAvgPool2d((1, 1))self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)def forward(self, x):b, c, h, w = x.size()group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,wx_h = self.pool_h(group_x)x_w = self.pool_w(group_x).permute(0, 1, 3, 2)hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))x_h, x_w = torch.split(hw, [h, w], dim=2)x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())x2 = self.conv3x3(group_x)x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwx21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwweights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)return (group_x * weights.sigmoid()).reshape(b, c, h, w)

注册模块

ultralytics/nn/tasks.py文件开头添加:

from ultralytics.nn.extra_modules import *

ultralytics/nn/tasks.py文件中parse_model函数添加:

elif m in {EMA}:args = [ch[f], *args]

配置yaml文件

yolov8-ema.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, EMA, []]- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 19 (P4/16-medium)- [-1, 1, EMA, []]- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 23 (P5/32-large)- [-1, 1, EMA, []]- [[16, 20, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

三、模型测试

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLOmodel = YOLO("yolov8n-ema.yaml")  # build a new model from scratch
                   from  n    params  module                                       arguments0                  -1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]1                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]2                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]3                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]4                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]5                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]6                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]7                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]8                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]9                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]10                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']11             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]12                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']14             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]15                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]16                  -1  1       672  ultralytics.nn.extra_modules.attention.EMA   [64]17                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]18            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]19                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]20                  -1  1      2624  ultralytics.nn.extra_modules.attention.EMA   [128]21                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]22             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]23                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]24                  -1  1     10368  ultralytics.nn.extra_modules.attention.EMA   [256]25        [16, 20, 24]  1    897664  ultralytics.nn.modules.head.Detect           [80, [64, 128, 256]]
YOLOv8n-ema summary: 249 layers, 3170864 parameters, 3170848 gradients, 9.1 GFLOPs

四、模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO# Load a model
model = YOLO("yolov8n-ema.yaml")  # build a new model from scratch# Use the model
model.train(data="./mydata/data.yaml",epochs=300,batch=32,imgsz=640,workers=8,device=0,project="runs/train",name='exp')  # train the model

五、总结

  • 模型的训练具有很大的随机性,您可能需要点运气和更多的训练次数才能达到最高的 mAP。
    在这里插入图片描述

这篇关于YOLOv8改进实战 | 注意力篇 | 引入基于跨空间学习的高效多尺度注意力EMA,小目标涨点明显的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1138025

相关文章

高效管理你的Linux系统: Debian操作系统常用命令指南

《高效管理你的Linux系统:Debian操作系统常用命令指南》在Debian操作系统中,了解和掌握常用命令对于提高工作效率和系统管理至关重要,本文将详细介绍Debian的常用命令,帮助读者更好地使... Debian是一个流行的linux发行版,它以其稳定性、强大的软件包管理和丰富的社区资源而闻名。在使用

Python中的随机森林算法与实战

《Python中的随机森林算法与实战》本文详细介绍了随机森林算法,包括其原理、实现步骤、分类和回归案例,并讨论了其优点和缺点,通过面向对象编程实现了一个简单的随机森林模型,并应用于鸢尾花分类和波士顿房... 目录1、随机森林算法概述2、随机森林的原理3、实现步骤4、分类案例:使用随机森林预测鸢尾花品种4.1

Golang使用minio替代文件系统的实战教程

《Golang使用minio替代文件系统的实战教程》本文讨论项目开发中直接文件系统的限制或不足,接着介绍Minio对象存储的优势,同时给出Golang的实际示例代码,包括初始化客户端、读取minio对... 目录文件系统 vs Minio文件系统不足:对象存储:miniogolang连接Minio配置Min

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

SpringBoot项目引入token设置方式

《SpringBoot项目引入token设置方式》本文详细介绍了JWT(JSONWebToken)的基本概念、结构、应用场景以及工作原理,通过动手实践,展示了如何在SpringBoot项目中实现JWT... 目录一. 先了解熟悉JWT(jsON Web Token)1. JSON Web Token是什么鬼

如何用Java结合经纬度位置计算目标点的日出日落时间详解

《如何用Java结合经纬度位置计算目标点的日出日落时间详解》这篇文章主详细讲解了如何基于目标点的经纬度计算日出日落时间,提供了在线API和Java库两种计算方法,并通过实际案例展示了其应用,需要的朋友... 目录前言一、应用示例1、天安门升旗时间2、湖南省日出日落信息二、Java日出日落计算1、在线API2

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06