YOLOv9改进策略 | 添加注意力篇 | TripletAttention三重注意力机制(附代码+机制原理+添加教程)

本文主要是介绍YOLOv9改进策略 | 添加注意力篇 | TripletAttention三重注意力机制(附代码+机制原理+添加教程),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 一、本文介绍

本文给大家带来的改进是Triplet Attention三重注意力机制。这个机制,它通过三个不同的视角来分析输入的数据,就好比三个人从不同的角度来观察同一幅画,然后共同决定哪些部分最值得注意。三重注意力机制的主要思想是在网络中引入了一种新的注意力模块,这个模块包含三个分支,分别关注图像的不同维度。比如说,一个分支可能专注于图像的宽度,另一个分支专注于高度,第三个分支则聚焦于图像的深度,即色彩和纹理等特征。这样一来,网络就能够更全面地理解图像内容,就像是得到了一副三维眼镜,能够看到图片的立体效果一样。

专栏地址:YOLOv9有效涨点专栏-持续复现各种顶会内容-有效涨点-全网改进最全的专栏 

目录

 一、本文介绍

二、Triplet Attention机制原理

2.1 Triplet Attention的基本原理 

2.2 Triplet Attention和其它简单注意力机制的对比 

2.3 Triplet Attention的实现流程

三、Triplet Attention的核心代码

四、手把手教你添加Triplet Attention

4.1 细节修改教程

4.1.1 修改一

​4.1.2 修改二

4.1.3 修改三 

4.1.4 修改四

4.2 Triplet Attention的yaml文件

4.2.1 Triplet Attention的yaml文件一

4.2.2 Triplet Attention的yaml文件二

4.3 Triplet Attention运行成功截图

五、本文总结 


二、Triplet Attention机制原理

论文地址:官方论文地址

代码地址:官方代码地址


2.1 Triplet Attention的基本原理 

三重注意力(Triplet Attention)的基本原理是利用三支结构捕获输入数据的跨维度交互,从而计算注意力权重。这个方法能够构建输入通道或空间位置之间的相互依赖性,而且计算代价小。三重注意力由三个分支组成,每个分支负责捕获空间维度H或W与通道维度C之间的交互特征。通过对每个分支中的输入张量进行排列变换,然后通过Z池操作和一个大小为k×k的卷积层,生成注意力权重。这些权重是通过一个S形激活层生成的,然后应用于排列变换后的输入张量,再变换回原来的输入形状 

三重注意力(Triplet Attention)的主要改进点包括:

  1. 跨维度的注意力权重计算: 通过一个创新的三支结构捕获通道、高度、宽度三个维度之间的交互关系来计算注意力权重。

  2. 旋转操作和残差变换: 通过旋转输入张量和应用残差变换来建立不同维度间的依赖,这是三重注意力机制中的关键步骤。

  3. 维度间依赖性的重要性: 强调在计算注意力权重时,捕获跨维度依赖性的重要性,这是三重注意力的核心直觉和设计理念。

下面的图片是三重注意力的一个抽象表示图,展示了三个分支如何捕获跨维度交互。图中的每个子图表示三重注意力中的一个分支: 

1. 分支(a): 这个分支直接处理输入张量,没有进行旋转,然后通过残差变换来提取特征。

2. 分支(b): 这个分支首先沿着宽度(W)和通道(C)的维度旋转输入张量,然后进行残差变换。

3. 分支(c): 这个分支沿着高度(H)和通道(C)的维度旋转输入张量,之后同样进行残差变换。

总结:通过这样的设计,三重注意力模型能够有效地捕获输入张量中的空间和通道维度之间的交互关系。这种方法使模型能够构建通道与空间位置之间的相互依赖性,提高模型对特征的理解能力。


2.2 Triplet Attention和其它简单注意力机制的对比 

下面的图片是论文中三重注意力机制和其它注意力机制的一个对比大家有兴趣可以看看,横向扩展以下自己的知识库。

这张图片是一幅对比不同注意力模块的图示,其中包括:

1.Squeeze Excitation (SE) Module:
这个模块使用全局平均池化 (Global Avg Pool) 生成通道描述符,接着通过两个全连接层(1x1 Conv),中间使用ReLU激活函数,最后通过Sigmoid函数生成每个通道的权重。

2. Convolutional Block Attention Module (CBAM):
首先使用全局平均池化和全局最大池化(GAP + GMP)结合,再通过一个卷积层和ReLU激活函数,最后经过另一个卷积层和Sigmoid函数生成注意力权重。

3. Global Context (GC) Module:
从一个1x1卷积层开始,经过Softmax函数进行归一化,接着进行另一个1x1卷积,然后使用LayerNorm和最终的1x1卷积,通过广播加法结合原始特征图。

4. Triplet Attention (我们的方法):
分为三个分支,每个分支进行不同的处理:通道池化后的7x7卷积,Z池化,再接一个7x7卷积,然后是批量归一化和Sigmoid函数。每个分支都有一个Permute操作来调整维度。最后,三个分支的结果通过平均池化聚合起来生成最终的注意力权重。

每种模块都设计用于处理特征图(C x H x W),其中C是通道数,H是高度,W是宽度。这些模块通过不同方式计算注意力权重,增强网络对特征的重要部分的关注度,从而在各种视觉任务中提高性能。图片中的符号⊗代表矩阵乘法,⊕代表广播元素级加法。


2.3 Triplet Attention的实现流程

下面的图片是三重注意力(Triplet Attention)的具体实现流程图。图中详细展示了三个分支如何处理输入张量,并最终合成三重注意力。下面是对这个过程的描述: 

  1. 上部分支: 负责计算通道维度C和空间维度W的注意力权重。这个分支对输入张量进行Z池化(Z-Pool)操作,然后通过一个卷积层(Conv),接着用Sigmoid函数生成注意力权重。

  2. 中部分支: 负责捕获通道维度C与空间维度H和W之间的依赖性。这个分支首先进行相同的Z池化和卷积操作,然后同样通过Sigmoid函数生成注意力权重。

  3. 下部分支: 用于捕获空间维度之间的依赖性。这个分支保持输入的身份(Identity,即不改变输入),执行Z池化和卷积操作,之后也通过Sigmoid函数生成注意力权重。

每个分支在生成注意力权重后,会对输入进行排列(Permutation),然后将三个分支的输出进行平均聚合(Avg),最终得到三重注意力输出。

这种结构通过不同的旋转和排列操作,能够综合不同维度上的信息,更好地捕获数据的内在特征,同时这种方法在计算上是高效的,并且可以作为一个模块加入到现有的网络架构中,增强网络对复杂数据结构的理解和处理能力。


三、Triplet Attention的核心代码

使用方式看章节四!

import torch
import torch.nn as nn__all__ = ['TripletAttention']def autopad(k, p=None, d=1):  # kernel, padding, dilation# Pad to 'same' shape outputsif d > 1:k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-sizeif p is None:p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-padreturn pclass Conv(nn.Module):# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)default_act = nn.SiLU()  # default activationdef __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):super().__init__()self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)self.bn = nn.BatchNorm2d(c2)self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()def forward(self, x):return self.act(self.bn(self.conv(x)))def forward_fuse(self, x):return self.act(self.conv(x))class BasicConv(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,bn=True, bias=False):super(BasicConv, self).__init__()self.out_channels = out_planesself.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,dilation=dilation, groups=groups, bias=bias)self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else Noneself.relu = nn.ReLU() if relu else Nonedef forward(self, x):x = self.conv(x)if self.bn is not None:x = self.bn(x)if self.relu is not None:x = self.relu(x)return xclass ZPool(nn.Module):def forward(self, x):return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)class AttentionGate(nn.Module):def __init__(self):super(AttentionGate, self).__init__()kernel_size = 7self.compress = ZPool()self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)def forward(self, x):x_compress = self.compress(x)x_out = self.conv(x_compress)scale = torch.sigmoid_(x_out)return x * scaleclass TripletAttention(nn.Module):def __init__(self, no_spatial=False):super(TripletAttention, self).__init__()self.cw = AttentionGate()self.hc = AttentionGate()self.no_spatial = no_spatialif not no_spatial:self.hw = AttentionGate()def forward(self, x):x_perm1 = x.permute(0, 2, 1, 3).contiguous()x_out1 = self.cw(x_perm1)x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()x_perm2 = x.permute(0, 3, 2, 1).contiguous()x_out2 = self.hc(x_perm2)x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()if not self.no_spatial:x_out = self.hw(x)x_out = 1 / 3 * (x_out + x_out11 + x_out21)else:x_out = 1 / 2 * (x_out11 + x_out21)return x_out


四、手把手教你添加Triplet Attention

4.1 细节修改教程

4.1.1 修改一

我们找到如下的目录'yolov9-main/models'在这个目录下创建一个文件目录(注意是目录,因为我这个专栏会出很多的更新,这里用一种一劳永逸的方法)文件目录起名modules,然后在下面新建一个文件,将我们的代码复制粘贴进去。


​4.1.2 修改二

然后新建一个__init__.py文件,然后我们在里面添加一行代码。注意标记一个'.'其作用是标记当前目录。


4.1.3 修改三 

然后我们找到如下文件''models/yolo.py''在开头的地方导入我们的模块按照如下修改->

(如果你看了我多个改进机制此处只需要添加一个即可,无需重复添加)

​​​​


4.1.4 修改四

然后我们找到parse_model方法,按照如下修改->

        elif m in {TripletAttention}:c2 = ch[f]args = [c2, *args]

到此就修改完成了,复制下面的ymal文件即可运行。


4.2 Triplet Attention的yaml文件

4.2.1 Triplet Attention的yaml文件一

下面的配置文件为我修改的Triplet Attention的位置,参数的位置里面什么都不用添加空着就行,大家复制粘贴我的就可以运行,同时我提供多个版本给大家,根据我的经验可能涨点的位置。

# YOLOv9# parameters
nc: 80  # number of classes
depth_multiple: 1  # model depth multiple
width_multiple: 1  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()# anchors
anchors: 3# YOLOv9 backbone
backbone:[[-1, 1, Silence, []],# conv down[-1, 1, Conv, [64, 3, 2]],  # 1-P1/2# conv down[-1, 1, Conv, [128, 3, 2]],  # 2-P2/4# elan-1 block[-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3# conv down[-1, 1, Conv, [256, 3, 2]],  # 4-P3/8# elan-2 block[-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5# conv down[-1, 1, Conv, [512, 3, 2]],  # 6-P4/16# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7# conv down[-1, 1, Conv, [512, 3, 2]],  # 8-P5/32# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9]# YOLOv9 head
head:[[-1, 1, TripletAttention, []],  # 添加一行我们的改进机制# elan-spp block[-1, 1, SPPELAN, [512, 256]],  # 11# up-concat merge[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 7], 1, Concat, [1]],  # cat backbone P4# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 14# up-concat merge[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 5], 1, Concat, [1]],  # cat backbone P3# elan-2 block[-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 17 (P3/8-small)# conv-down merge[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 20 (P4/16-medium)# conv-down merge[-1, 1, Conv, [512, 3, 2]],[[-1, 11], 1, Concat, [1]],  # cat head P5# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 23 (P5/32-large)# routing[5, 1, CBLinear, [[256]]], # 24[7, 1, CBLinear, [[256, 512]]], # 25[9, 1, CBLinear, [[256, 512, 512]]], # 26# conv down[0, 1, Conv, [64, 3, 2]],  # 27-P1/2# conv down[-1, 1, Conv, [128, 3, 2]],  # 28-P2/4# elan-1 block[-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 29# conv down fuse[-1, 1, Conv, [256, 3, 2]],  # 30-P3/8[[24, 25, 26, -1], 1, CBFuse, [[0, 0, 0]]], # 31# elan-2 block[-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 32# conv down fuse[-1, 1, Conv, [512, 3, 2]],  # 33-P4/16[[25, 26, -1], 1, CBFuse, [[1, 1]]], # 34# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 35# conv down fuse[-1, 1, Conv, [512, 3, 2]],  # 36-P5/32[[26, -1], 1, CBFuse, [[2]]], # 37# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 38# detect[[32, 35, 38, 17, 20, 23], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)]


4.2.2 Triplet Attention的yaml文件二

# YOLOv9# parameters
nc: 80  # number of classes
depth_multiple: 1  # model depth multiple
width_multiple: 1  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()# anchors
anchors: 3# YOLOv9 backbone
backbone:[[-1, 1, Silence, []],# conv down[-1, 1, Conv, [64, 3, 2]],  # 1-P1/2# conv down[-1, 1, Conv, [128, 3, 2]],  # 2-P2/4# elan-1 block[-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3# conv down[-1, 1, Conv, [256, 3, 2]],  # 4-P3/8# elan-2 block[-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5# conv down[-1, 1, Conv, [512, 3, 2]],  # 6-P4/16# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7# conv down[-1, 1, Conv, [512, 3, 2]],  # 8-P5/32# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9]# YOLOv9 head
head:[# elan-spp block[-1, 1, SPPELAN, [512, 256]],  # 10# up-concat merge[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 7], 1, Concat, [1]],  # cat backbone P4# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13# up-concat merge[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 5], 1, Concat, [1]],  # cat backbone P3# elan-2 block[-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)[-1, 1, TripletAttention, []],  # 17 添加一行我们的改进机制# conv-down merge[-1, 1, Conv, [256, 3, 2]],[[-1, 13], 1, Concat, [1]],  # cat head P4# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 20 (P4/16-medium)[-1, 1, TripletAttention, []],  # 21 添加一行我们的改进机制# conv-down merge[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 24 (P5/32-large)[-1, 1, TripletAttention, []],  # 25 添加一行我们的改进机制# routing[5, 1, CBLinear, [[256]]], # 26[7, 1, CBLinear, [[256, 512]]], # 27[9, 1, CBLinear, [[256, 512, 512]]], # 28# conv down[0, 1, Conv, [64, 3, 2]],  # 29-P1/2# conv down[-1, 1, Conv, [128, 3, 2]],  # 30-P2/4# elan-1 block[-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 31# conv down fuse[-1, 1, Conv, [256, 3, 2]],  # 32-P3/8[[26, 27, 28, -1], 1, CBFuse, [[0, 0, 0]]], # 33# elan-2 block[-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 34[-1, 1, TripletAttention, []],  # 35 添加一行我们的改进机制# conv down fuse[-1, 1, Conv, [512, 3, 2]],  # 36-P4/16[[27, 28, -1], 1, CBFuse, [[1, 1]]], # 37# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 38[-1, 1, TripletAttention, []],  # 39 添加一行我们的改进机制# conv down fuse[-1, 1, Conv, [512, 3, 2]],  # 40-P5/32[[28, -1], 1, CBFuse, [[2]]], # 41# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 42[-1, 1, TripletAttention, []],  # 43 添加一行我们的改进机制# detect[[35, 39, 43, 17, 21, 25], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)]

4.3 Triplet Attention运行成功截图

附上我的运行记录确保我的教程是可用的。 


五、本文总结 

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv9改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏地址:YOLOv9有效涨点专栏-持续复现各种顶会内容-有效涨点-全网改进最全的专栏 

这篇关于YOLOv9改进策略 | 添加注意力篇 | TripletAttention三重注意力机制(附代码+机制原理+添加教程)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/933275

相关文章

uniapp接入微信小程序原生代码配置方案(优化版)

uniapp项目需要把微信小程序原生语法的功能代码嵌套过来,无需把原生代码转换为uniapp,可以配置拷贝的方式集成过来 1、拷贝代码包到src目录 2、vue.config.js中配置原生代码包直接拷贝到编译目录中 3、pages.json中配置分包目录,原生入口组件的路径 4、manifest.json中配置分包,使用原生组件 5、需要把原生代码包里的页面修改成组件的方

公共筛选组件(二次封装antd)支持代码提示

如果项目是基于antd组件库为基础搭建,可使用此公共筛选组件 使用到的库 npm i antdnpm i lodash-esnpm i @types/lodash-es -D /components/CommonSearch index.tsx import React from 'react';import { Button, Card, Form } from 'antd'

17.用300行代码手写初体验Spring V1.0版本

1.1.课程目标 1、了解看源码最有效的方式,先猜测后验证,不要一开始就去调试代码。 2、浓缩就是精华,用 300行最简洁的代码 提炼Spring的基本设计思想。 3、掌握Spring框架的基本脉络。 1.2.内容定位 1、 具有1年以上的SpringMVC使用经验。 2、 希望深入了解Spring源码的人群,对 Spring有一个整体的宏观感受。 3、 全程手写实现SpringM

(超详细)YOLOV7改进-Soft-NMS(支持多种IoU变种选择)

1.在until/general.py文件最后加上下面代码 2.在general.py里面找到这代码,修改这两个地方 3.之后直接运行即可

YOLOv8改进 | SPPF | 具有多尺度带孔卷积层的ASPP【CVPR2018】

💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡 专栏目录 :《YOLOv8改进有效涨点》专栏介绍 & 专栏目录 | 目前已有40+篇内容,内含各种Head检测头、损失函数Loss、Backbone、Neck、NMS等创新点改进——点击即可跳转 Atrous Spatial Pyramid Pooling (ASPP) 是一种在深度学习框架中用于语义分割的网络结构,它旨

代码随想录算法训练营:12/60

非科班学习算法day12 | LeetCode150:逆波兰表达式 ,Leetcode239: 滑动窗口最大值  目录 介绍 一、基础概念补充: 1.c++字符串转为数字 1. std::stoi, std::stol, std::stoll, std::stoul, std::stoull(最常用) 2. std::stringstream 3. std::atoi, std

记录AS混淆代码模板

开启混淆得先在build.gradle文件中把 minifyEnabled false改成true,以及shrinkResources true//去除无用的resource文件 这些是写在proguard-rules.pro文件内的 指定代码的压缩级别 -optimizationpasses 5 包明不混合大小写 -dontusemixedcaseclassnames 不去忽略非公共

Linux系统稳定性的奥秘:探究其背后的机制与哲学

在计算机操作系统的世界里,Linux以其卓越的稳定性和可靠性著称,成为服务器、嵌入式系统乃至个人电脑用户的首选。那么,是什么造就了Linux如此之高的稳定性呢?本文将深入解析Linux系统稳定性的几个关键因素,揭示其背后的技术哲学与实践。 1. 开源协作的力量Linux是一个开源项目,意味着任何人都可以查看、修改和贡献其源代码。这种开放性吸引了全球成千上万的开发者参与到内核的维护与优化中,形成了

Steam邮件推送内容有哪些?配置教程详解!

Steam邮件推送功能是否安全?如何个性化邮件推送内容? Steam作为全球最大的数字游戏分发平台之一,不仅提供了海量的游戏资源,还通过邮件推送为用户提供最新的游戏信息、促销活动和个性化推荐。AokSend将详细介绍Steam邮件推送的主要内容。 Steam邮件推送:促销优惠 每当平台举办大型促销活动,如夏季促销、冬季促销、黑色星期五等,用户都会收到邮件通知。这些邮件详细列出了打折游戏、

麻了!一觉醒来,代码全挂了。。

作为⼀名程序员,相信大家平时都有代码托管的需求。 相信有不少同学或者团队都习惯把自己的代码托管到GitHub平台上。 但是GitHub大家知道,经常在访问速度这方面并不是很快,有时候因为网络问题甚至根本连网站都打不开了,所以导致使用体验并不友好。 经常一觉醒来,居然发现我竟然看不到我自己上传的代码了。。 那在国内,除了GitHub,另外还有一个比较常用的Gitee平台也可以用于