yolov8添加注意力机制模块-ShuffleAttention

2024-02-27 03:04

本文主要是介绍yolov8添加注意力机制模块-ShuffleAttention,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

修改

原本打算把ShuffleAttention模块先写进conv.py文件中,然后在引入tasks.py文件中。但是不知道咋回事,在tasks.py文件中引入报红。所以干脆直接把ShuffleAttention模块写进了tasks.py文件中。

from torch.nn import init
from torch.nn.parameter import Parameterclass ShuffleAttention(nn.Module):def __init__(self, channel=512, reduction=16, G=8):super().__init__()self.G = Gself.channel = channelself.avg_pool = nn.AdaptiveAvgPool2d(1)self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sigmoid = nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)@staticmethoddef channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return xdef forward(self, x):b, c, h, w = x.size()# group into subfeaturesx = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w# channel_splitx_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w# channel attentionx_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1x_channel = x_0 * self.sigmoid(x_channel)# spatial attentionx_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,wx_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,wx_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w# concatenate along channel axisout = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,wout = out.contiguous().view(b, -1, h, w)# channel shuffleout = self.channel_shuffle(out, 2)return out

tasks.py文件中,在指定位置添加如下代码。在函数parse_model处。

        elif m is ShuffleAttention:c1, c2 = ch[f], args[0]if c2 != nc:c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, *args[1:]]

修改yolov8.yaml文件。改动的地方为标红的地方。

# Ultralytics YOLO  , GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 2  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [ 0.33, 0.25, 1024 ]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [ 0.33, 0.50, 1024 ]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [ 0.67, 0.75, 768 ]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [ 1.00, 1.00, 512 ]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [ 1.00, 1.25, 512 ]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [ -1, 1, Conv, [ 64, 3, 2 ] ]  # 0-P1/2- [ -1, 1, Conv, [ 128, 3, 2 ] ]  # 1-P2/4- [ -1, 3, C2f, [ 128, True ] ]- [ -1, 1, Conv, [ 256, 3, 2 ] ]  # 3-P3/8- [ -1, 6, C2f, [ 256, True ] ]- [ -1, 1, Conv, [ 512, 3, 2 ] ]  # 5-P4/16- [ -1, 6, C2f, [ 512, True ] ]- [ -1, 1, Conv, [ 1024, 3, 2 ] ]  # 7-P5/32- [ -1, 3, C2f, [ 1024, True ] ]- [ -1, 1, SPPF, [ 1024, 5 ] ]  # 9# YOLOv8.0n head
head:- [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]- [ [ -1, 6 ], 1, Concat, [ 1 ] ]  # cat backbone P4- [ -1, 3, C2f, [ 512 ] ]  # 12- [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]- [ [ -1, 4 ], 1, Concat, [ 1 ] ]  # cat backbone P3- [ -1, 3, C2f, [ 256 ] ]  # 15 (P3/8-small)- [ -1, 1, Conv, [ 256, 3, 2 ] ]- [ [ -1, 12 ], 1, Concat, [ 1 ] ]  # cat head P4- [ -1, 3, C2f, [ 512 ] ]  # 18 (P4/16-medium)- [ -1, 1, Conv, [ 512, 3, 2 ] ]- [ [ -1, 9 ], 1, Concat, [ 1 ] ]  # cat head P5- [ -1, 3, C2f, [ 1024 ] ]  # 21 (P5/32-large)- [ -1, 3, ShuffleAttention, [ 1024 ] ]- [ [ 15, 18, 22 ], 1, Detect, [ nc ] ]  # Detect(P3, P4, P5)

测试打印网络。

分析

下面对ShuffleAttention模块一部分一部分进行解读。

    def __init__(self, channel=512, reduction=16, G=8):super().__init__()self.G = Gself.channel = channelself.avg_pool = nn.AdaptiveAvgPool2d(1)self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sigmoid = nn.Sigmoid()

  • self.G:存储组的数量。
  • self.channel:存储输入特征的通道数。
  • self.avg_pool:自适应平均池化层,输出大小为1x1,用于全局池化操作。
  • self.gn:分组归一化层。
  • self.cweightself.cbias:通道注意力的可学习权重和偏置。
  • self.sweightself.sbias:空间注意力的可学习权重和偏置。
  • self.sigmoid:Sigmoid激活函数。

注:分组归一化层self.gn,是对输入的x,按照通道分成几组,然后在每组里在分别进行归一化。经过这一层,会改变其中的值,但不改变形状。Parameter()方法,使用Parameter包装一个张量表示这个张量是模型中的一个参数,它会在模型的训练过程中被优化器更新。

    def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)

init_weights 方法是用于初始化神经网络模型中不同类型层的权重和偏置参数的函数。遍历模型中的所有模块,对于nn.Conv2d,nn.BatchNorm2d,nn.Linear分别进行不同的初始化操作。

    def channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return x

在一个卷积神经网络中对输入的4维张量(通常代表图像批次)的通道进行混洗。这种操作通常用于那些使用分组卷积的网络架构(例如ShuffleNet)中来提升模型性能,它通过在通道间进行信息的交换来增强特征的表达能力。

逐行进行解释:

  1. b, c, h, w = x.shape:这行代码获取输入张量 x 的形状,其中 b 是批次大小,c 是通道数,h 是特征图的高度,w 是特征图的宽度。

  2. x = x.reshape(b, groups, -1, h, w):这里,张量 x 被重新塑形(reshape)为一个新的形状。它首先按照批次大小 b,然后是分组数 groups 进行分割。-1 表示自动计算该维度的大小,具体来说,这里的 -1 表示每个分组的通道数(即 c // groups)。最后两维 h 和 w 保持不变。这一步准备将通道分为 groups 组,每组具有相等数量的通道。

  3. x = x.permute(0, 2, 1, 3, 4):permute 函数用于对张量的维度进行重新排列。这里,它将分组的维度(索引为1的维度)和通道的维度(索引为2的维度)调换位置。经过这一步操作后,张量的形状将变为 (b, -1, groups, h, w)

  4. x = x.reshape(b, -1, h, w):在上一步维度调换之后,此行代码再次将张量 x 进行重新塑形,使其变回原始的4维形状 (b, c, h, w),其中分组内的通道现在已经被混洗。由于维度调换的操作,原先属于同一组的通道现在分散到了不同的位置,从而完成了通道混洗的过程。

  5. return x:最后,返回了经过通道混洗后的张量 x

下面对forward中部分代码进行解析

x = x.view(b * self.G, -1, h, w)

这里将输入x的形状重构,将批次内的图像分成G组,每组的通道数变为原通道数除以G

x_0, x_1 = x.chunk(2, dim=1)  # 将输入x沿着通道维度分成两块

这里使用chunk函数沿着通道维度将输入x分成两部分,x_0x_1。每部分包含原本通道数的一半。

# channel attentionx_channel = self.avg_pool(x_0)  # 对x_0进行全局平均池化,输出维度变为(batch_size * G, new_channels, 1, 1)x_channel = self.cweight * x_channel + self.cbias  # 应用学习到的权重和偏置x_channel = x_0 * self.sigmoid(x_channel)  # 通过Sigmoid函数后与x_0相乘实现通道注意力

这是通道注意力机制的实现。首先,对x_0应用全局平均池化(avg_pool)来得到每个通道的全局特征,然后通过一个权重cweight和偏置cbias进行线性变换(这些可能是在类的初始化中定义的参数)。接着,通过Sigmoid函数激活这个通道特征图(x_channel),并将它与原始的x_0相乘,实现对不同通道的不同权重分配,即通道注意力。

# spatial attentionx_spatial = self.gn(x_1)  # 对x_1进行分组归一化x_spatial = self.sweight * x_spatial + self.sbias  # 应用学习到的权重和偏置x_spatial = x_1 * self.sigmoid(x_spatial)  # 通过Sigmoid函数后与x_1相乘实现空间注意力

这是空间注意力机制的实现。首先,对x_1应用分组归一化(gn),然后通过一个权重sweight和偏置sbias进行线性变换。最后,相同地,通过Sigmoid函数激活并与x_1相乘,实现对空间位置不同重要性的加权,即空间注意力。

# concatenate along channel axisout = torch.cat([x_channel, x_spatial], dim=1)  # 将通道和空间注意力的结果在通道维度上拼接out = out.contiguous().view(b, -1, h, w)  # 重构形状为(batch_size, channels, height, width)

这里将通道注意力和空间注意力处理后的两部分沿通道维度拼接回一个完整的张量。然后改变其形状,以确保它与原始输入的批次大小、高度、宽度一致。

    # channel shuffleout = self.channel_shuffle(out, 2)return out

最后,进行通道混洗操作,改善跨通道的信息流通。

这篇关于yolov8添加注意力机制模块-ShuffleAttention的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/751018

相关文章

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

【编程底层思考】垃圾收集机制,GC算法,垃圾收集器类型概述

Java的垃圾收集(Garbage Collection,GC)机制是Java语言的一大特色,它负责自动管理内存的回收,释放不再使用的对象所占用的内存。以下是对Java垃圾收集机制的详细介绍: 一、垃圾收集机制概述: 对象存活判断:垃圾收集器定期检查堆内存中的对象,判断哪些对象是“垃圾”,即不再被任何引用链直接或间接引用的对象。内存回收:将判断为垃圾的对象占用的内存进行回收,以便重新使用。

【Tools】大模型中的自注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 自注意力机制(Self-Attention)是一种在Transformer等大模型中经常使用的注意力机制。该机制通过对输入序列中的每个元素计算与其他元素之间的相似性,

如何通俗理解注意力机制?

1、注意力机制(Attention Mechanism)是机器学习和深度学习中一种模拟人类注意力的方法,用于提高模型在处理大量信息时的效率和效果。通俗地理解,它就像是在一堆信息中找到最重要的部分,把注意力集中在这些关键点上,从而更好地完成任务。以下是几个简单的比喻来帮助理解注意力机制: 2、寻找重点:想象一下,你在阅读一篇文章的时候,有些段落特别重要,你会特别注意这些段落,反复阅读,而对其他部分

【Tools】大模型中的注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 在大模型中,注意力机制是一种重要的技术,它被广泛应用于自然语言处理领域,特别是在机器翻译和语言模型中。 注意力机制的基本思想是通过计算输入序列中各个位置的权重,以确

Jenkins构建Maven聚合工程,指定构建子模块

一、设置单独编译构建子模块 配置: 1、Root POM指向父pom.xml 2、Goals and options指定构建模块的参数: mvn -pl project1/project1-son -am clean package 单独构建project1-son项目以及它所依赖的其它项目。 说明: mvn clean package -pl 父级模块名/子模块名 -am参数

寻迹模块TCRT5000的应用原理和功能实现(基于STM32)

目录 概述 1 认识TCRT5000 1.1 模块介绍 1.2 电气特性 2 系统应用 2.1 系统架构 2.2 STM32Cube创建工程 3 功能实现 3.1 代码实现 3.2 源代码文件 4 功能测试 4.1 检测黑线状态 4.2 未检测黑线状态 概述 本文主要介绍TCRT5000模块的使用原理,包括该模块的硬件实现方式,电路实现原理,还使用STM32类