yolov8添加注意力机制模块-CBAM

2024-02-25 11:52

本文主要是介绍yolov8添加注意力机制模块-CBAM,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

修改

  1. 在tasks.py(路径:ultralytics-main/ultralytics-main - attention/ultralytics/nn/tasks.py)文件中,引入CBAM模块。因为yolov8源码中已经包含CBAM模块,在conv.py文件中(路径:ultralytics-main/ultralytics-main - attention/ultralytics/nn/modules/conv.py),这里就就用自己写了。
  2. 修改tasks.py文件,搜索parse_model。在指定位置添加代码。
            elif m is CBAM:  # todo 源码修改 (增加了elif)"""ch[f]:上一层的args[0]:第0个参数c1:输入通道数c2:输出通道数"""c1, c2 = ch[f], args[0]# print("ch[f]:",ch[f])# print("args[0]:",args[0])# print("args:",args)# print("c1:",c1)# print("c2:",c2)if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)c2 = make_divisible(c2 * width, 8)args = [c1, *args[1:]]

    3.修改yolov8.yaml文件位置(ultralytics-main/ultralytics-main - attention/ultralytics/cfg/models/v8/yolov8.yaml)。修改head模块,修改的内容如下图。

        4.测试打印网络。已经添加成功。

分析

一般来说,注意力机制通常被分为以下基本四大类:

通道注意力 Channel Attention

空间注意力机制 Spatial Attention

时间注意力机制 Temporal Attention

分支注意力机制 Branch Attention

CBAM:通道注意力和空间注意力的集成者

源码解读

这段代码是对通道的注意力。首先经过自适应平均池化层,它会对每个输入通道的空间维度进行全局平均池化,并输出一个具有空间大小为 1x1 的特征图。然后是一个卷积操作,这相当于是对每个通道进行独立的全连接层变换,因为卷积核大小为1。

最后经过Sigmoid函数,将卷积层的输出转换为权重因子,范围在(0, 1)最后,这些权重因子与原始输入x逐元素相乘,以得到加权后的特征图,这一操作实现了注意力机制,允许模型专注于更有信息量的通道。

class ChannelAttention(nn.Module):"""Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""def __init__(self, channels: int) -> None:"""Initializes the class and sets the basic configurations and instance variables required."""super().__init__()self.pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)self.act = nn.Sigmoid()def forward(self, x: torch.Tensor) -> torch.Tensor:"""Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""return x * self.act(self.fc(self.pool(x)))

下面是一个空间注意力模块,旨在通过对输入特征图加权来强调或抑制某些空间区域。空间注意力通常用于强调图像的重要部分并抑制不重要的部分。

self.cv1 是一个卷积层,有两个输入通道,一个输出通道,和可选的 kernel_size 与 padding。由于 bias=False,这个卷积层不会有偏置参数。两个输入通道对应于输入特征图的均值和最大值。

forward中

  1. torch.mean(x, 1, keepdim=True) 计算输入张量 x 每个样本的通道维度的均值,keepdim=True 表示保持输出张量的维度不变。

  2. torch.max(x, 1, keepdim=True)[0] 计算输入张量 x 每个样本的通道维度的最大值,[0] 是因为 torch.max 返回一个元组,包含最大值和相应的索引。

  3. torch.cat([avg_out, max_out], 1) 将均值和最大值沿通道维度拼接起来,这样每个空间位置都有两个通道:其均值和最大值。

  4. self.cv1(x_cat) 对拼接的结果应用 1x2 卷积,生成一个单通道的特征图,该特征图对应于每个空间位置的注意力权重。

  5. self.act(...) 应用 Sigmoid 激活函数将注意力权重映射到 (0, 1) 范围内。

  6. x * scale 将原始输入 x 与计算得到的空间注意力权重相乘,这样每个空间位置的特征值都会根据其重要性加权,实现了特征重标定。

最终,forward 方法返回的是加权后的输入特征图(对特征图的每个元素值×权值),它突出了输入中更重要的空间区域。

class SpatialAttention(nn.Module):"""Spatial-attention module."""def __init__(self, kernel_size=7):"""Initialize Spatial-attention module with kernel size argument."""super().__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.act = nn.Sigmoid()def forward(self, x):"""Apply channel and spatial attention on input for feature recalibration."""return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))

下面就是CBAM,是上面两个模块的组合,通道注意力和空间注意力。通道注意力专注于哪些通道更重要,而空间注意力则集中在输入特征图中的哪些空间位置更重要。

  • 输入 x 首先通过 self.channel_attention,这个步骤会重新调整每个通道的重要性。
  • 然后,调整通道重要性后的特征图 x 通过 self.spatial_attention,这个步骤会重新调整特征图中每个位置的重要性。
  • 最终,这两个注意力机制的结果被串联起来,形成了最终的输出。

这种结构可以提高网络对于输入特征的逐通道和逐空间位置的重要性评估能力,进而可能提高模型的性能。

class CBAM(nn.Module):"""Convolutional Block Attention Module."""def __init__(self, c1, kernel_size=7):"""Initialize CBAM with given input channel (c1) and kernel size."""super().__init__()self.channel_attention = ChannelAttention(c1)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):"""Applies the forward pass through C1 module."""return self.spatial_attention(self.channel_attention(x))

这篇关于yolov8添加注意力机制模块-CBAM的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/745394

相关文章

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

【编程底层思考】垃圾收集机制,GC算法,垃圾收集器类型概述

Java的垃圾收集(Garbage Collection,GC)机制是Java语言的一大特色,它负责自动管理内存的回收,释放不再使用的对象所占用的内存。以下是对Java垃圾收集机制的详细介绍: 一、垃圾收集机制概述: 对象存活判断:垃圾收集器定期检查堆内存中的对象,判断哪些对象是“垃圾”,即不再被任何引用链直接或间接引用的对象。内存回收:将判断为垃圾的对象占用的内存进行回收,以便重新使用。

【Tools】大模型中的自注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 自注意力机制(Self-Attention)是一种在Transformer等大模型中经常使用的注意力机制。该机制通过对输入序列中的每个元素计算与其他元素之间的相似性,

如何通俗理解注意力机制?

1、注意力机制(Attention Mechanism)是机器学习和深度学习中一种模拟人类注意力的方法,用于提高模型在处理大量信息时的效率和效果。通俗地理解,它就像是在一堆信息中找到最重要的部分,把注意力集中在这些关键点上,从而更好地完成任务。以下是几个简单的比喻来帮助理解注意力机制: 2、寻找重点:想象一下,你在阅读一篇文章的时候,有些段落特别重要,你会特别注意这些段落,反复阅读,而对其他部分

【Tools】大模型中的注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 在大模型中,注意力机制是一种重要的技术,它被广泛应用于自然语言处理领域,特别是在机器翻译和语言模型中。 注意力机制的基本思想是通过计算输入序列中各个位置的权重,以确

Jenkins构建Maven聚合工程,指定构建子模块

一、设置单独编译构建子模块 配置: 1、Root POM指向父pom.xml 2、Goals and options指定构建模块的参数: mvn -pl project1/project1-son -am clean package 单独构建project1-son项目以及它所依赖的其它项目。 说明: mvn clean package -pl 父级模块名/子模块名 -am参数

寻迹模块TCRT5000的应用原理和功能实现(基于STM32)

目录 概述 1 认识TCRT5000 1.1 模块介绍 1.2 电气特性 2 系统应用 2.1 系统架构 2.2 STM32Cube创建工程 3 功能实现 3.1 代码实现 3.2 源代码文件 4 功能测试 4.1 检测黑线状态 4.2 未检测黑线状态 概述 本文主要介绍TCRT5000模块的使用原理,包括该模块的硬件实现方式,电路实现原理,还使用STM32类