改进YOLOv5:添加EMA注意力机制

2023-10-14 16:20

本文主要是介绍改进YOLOv5:添加EMA注意力机制,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • 新建EMA.py文件
  • 修改yolo.py文件
    • 1.导入EMA.py
    • 2.修改parse_model
  • 修改yaml文件(yolov5s为例)
  • 参考


前言

本文主要介绍一种在YOLOv5-7.0中添加EMA注意力机制的方法。EMA注意力机制原论文地址,有关EMA注意力机制的解读可参考文章。

新建EMA.py文件

在yolov5的models文件中新建一个名为EMA.py文件,将下述代码复制到EMA.py文件中并保存。

import torch
from torch import nnclass EMA(nn.Module):def __init__(self, channels, factor=8):super(EMA, self).__init__()self.groups = factorassert channels // self.groups > 0self.softmax = nn.Softmax(-1)self.agp = nn.AdaptiveAvgPool2d((1, 1))self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)def forward(self, x):b, c, h, w = x.size()group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,wx_h = self.pool_h(group_x)x_w = self.pool_w(group_x).permute(0, 1, 3, 2)hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))x_h, x_w = torch.split(hw, [h, w], dim=2)x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())x2 = self.conv3x3(group_x)x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwx21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwweights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)return (group_x * weights.sigmoid()).reshape(b, c, h, w)

修改yolo.py文件

1.导入EMA.py

在yolo.py文件开头导入EMA.py,代码如下:

from models.EMA import EMA

代码放在yolo.py位置如下图所示:
在这里插入图片描述

2.修改parse_model

这里主要是添加通道参数,再添加一个elif,把EMA添加进去,代码如下:

 elif m is EMA:   args = [ch[f], *args]

添加上述代码的位置可参考下图:
在这里插入图片描述


修改yaml文件(yolov5s为例)

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, EMA, [8]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 15], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 11], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

上述代码将EMA注意力机制模块加在backbone层中最后C3模块后面,SPPF模块前面,仅供参考,具体添加位置要根据个人数据集的不同合理的添加。

[-1, 1, EMA, [8]], #-1代表连接上一层通道数,1是个数,8是EMA所需的参数(factor=8)

说明:因为在yolo.py文件parse_model函数中修改了通道参数,因此在yaml文件中无需添加通道参数,只需添加EMA函数所需的其他参数。在backbone中添加一层注意力机制模块,因此后续的层数都要加一,在head层中做如下改动。

[[-1, 15], 1, Concat, [1]],  #未改动前的第14层,在经过上述改动后改为15
[[-1, 11], 1, Concat, [1]],  #未改动前的第10层,在记过上述改动后改为11
[[18, 21, 24], 1, Detect, [nc, anchors]],  #17,20,23层改为18,21,24

运行train.py文件可以在输出终端窗口看到上图网络结构,可以看到在第9层已经成功添加EMA注意力机制模块。

                from  n    params  module                                  arguments                     0                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2]              1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]                2                -1  1     18816  models.common.C3                        [64, 64, 1]                   3                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               4                -1  2    115712  models.common.C3                        [128, 128, 2]                 5                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]              6                -1  3    625152  models.common.C3                        [256, 256, 3]                 7                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2]              8                -1  1   1182720  models.common.C3                        [512, 512, 1]                 9                -1  1     41216  models.EMA.EMA                          [512, 8]                      10                -1  1    656896  models.common.SPPF                      [512, 512, 5]                 11                -1  1    131584  models.common.Conv                      [512, 256, 1, 1]              12                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          13           [-1, 6]  1         0  models.common.Concat                    [1]                           14                -1  1    361984  models.common.C3                        [512, 256, 1, False]          15                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]              16                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          17           [-1, 4]  1         0  models.common.Concat                    [1]                           18                -1  1     90880  models.common.C3                        [256, 128, 1, False]          19                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]              20          [-1, 15]  1         0  models.common.Concat                    [1]                           21                -1  1    296448  models.common.C3                        [256, 256, 1, False]          22                -1  1    590336  models.common.Conv                      [256, 256, 3, 2]              23          [-1, 11]  1         0  models.common.Concat                    [1]                           24                -1  1   1182720  models.common.C3                        [512, 512, 1, False]          25      [18, 21, 24]  1     16182  models.yolo.Detect                      [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
YOLOv5sEMA summary: 222 layers, 7063542 parameters, 7063542 gradients, 16.2 GFLOPs

参考

https://www.bilibili.com/video/BV1s84y1775U/?spm_id_from=333.788&vd_source=f83457e2adc10b543ae4c742fba1e3b2
https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/131347981

这篇关于改进YOLOv5:添加EMA注意力机制的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/211676

相关文章

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

基于 YOLOv5 的积水检测系统:打造高效智能的智慧城市应用

在城市发展中,积水问题日益严重,特别是在大雨过后,积水往往会影响交通甚至威胁人们的安全。通过现代计算机视觉技术,我们能够智能化地检测和识别积水区域,减少潜在危险。本文将介绍如何使用 YOLOv5 和 PyQt5 搭建一个积水检测系统,结合深度学习和直观的图形界面,为用户提供高效的解决方案。 源码地址: PyQt5+YoloV5 实现积水检测系统 预览: 项目背景

【编程底层思考】垃圾收集机制,GC算法,垃圾收集器类型概述

Java的垃圾收集(Garbage Collection,GC)机制是Java语言的一大特色,它负责自动管理内存的回收,释放不再使用的对象所占用的内存。以下是对Java垃圾收集机制的详细介绍: 一、垃圾收集机制概述: 对象存活判断:垃圾收集器定期检查堆内存中的对象,判断哪些对象是“垃圾”,即不再被任何引用链直接或间接引用的对象。内存回收:将判断为垃圾的对象占用的内存进行回收,以便重新使用。

【Tools】大模型中的自注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 自注意力机制(Self-Attention)是一种在Transformer等大模型中经常使用的注意力机制。该机制通过对输入序列中的每个元素计算与其他元素之间的相似性,

如何通俗理解注意力机制?

1、注意力机制(Attention Mechanism)是机器学习和深度学习中一种模拟人类注意力的方法,用于提高模型在处理大量信息时的效率和效果。通俗地理解,它就像是在一堆信息中找到最重要的部分,把注意力集中在这些关键点上,从而更好地完成任务。以下是几个简单的比喻来帮助理解注意力机制: 2、寻找重点:想象一下,你在阅读一篇文章的时候,有些段落特别重要,你会特别注意这些段落,反复阅读,而对其他部分

【Tools】大模型中的注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 在大模型中,注意力机制是一种重要的技术,它被广泛应用于自然语言处理领域,特别是在机器翻译和语言模型中。 注意力机制的基本思想是通过计算输入序列中各个位置的权重,以确

FreeRTOS内部机制学习03(事件组内部机制)

文章目录 事件组使用的场景事件组的核心以及Set事件API做的事情事件组的特殊之处事件组为什么不关闭中断xEventGroupSetBitsFromISR内部是怎么做的? 事件组使用的场景 学校组织秋游,组长在等待: 张三:我到了 李四:我到了 王五:我到了 组长说:好,大家都到齐了,出发! 秋游回来第二天就要提交一篇心得报告,组长在焦急等待:张三、李四、王五谁先写好就交谁的

UVM:callback机制的意义和用法

1. 作用         Callback机制在UVM验证平台,最大用处就是为了提高验证平台的可重用性。在不创建复杂的OOP层次结构前提下,针对组件中的某些行为,在其之前后之后,内置一些函数,增加或者修改UVM组件的操作,增加新的功能,从而实现一个环境多个用例。此外还可以通过Callback机制构建异常的测试用例。 2. 使用步骤         (1)在UVM组件中内嵌callback函

一种改进的red5集群方案的应用、基于Red5服务器集群负载均衡调度算法研究

转自: 一种改进的red5集群方案的应用: http://wenku.baidu.com/link?url=jYQ1wNwHVBqJ-5XCYq0PRligp6Y5q6BYXyISUsF56My8DP8dc9CZ4pZvpPz1abxJn8fojMrL0IyfmMHStpvkotqC1RWlRMGnzVL1X4IPOa_  基于Red5服务器集群负载均衡调度算法研究 http://ww