YOLOv5改进 | 注意力机制 | 添加三重注意力机制 TripletAttention【原理 + 完整代码】

本文主要是介绍YOLOv5改进 | 注意力机制 | 添加三重注意力机制 TripletAttention【原理 + 完整代码】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡

得益于在通道或空间位置之间建立相互依赖关系的能力,近年来,注意力机制在计算机视觉任务中得到了广泛的研究和应用。一种轻量级但有效的注意力机制——三重注意力,这是一种通过使用三分支结构捕获跨维度交互来计算注意力权重的创新方法。对于一个输入张量,三重注意力通过旋转变换建立跨维度依赖关系,并通过残差变换编码跨通道和空间信息,几乎不增加计算开销。在本文中,给大家带来的教程是将原来的网络添加TripletAttention。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址 YOLOv5改进+入门——持续更新各种有效涨点方法 点击即可跳转

目录

1.原理

2. TripletAttention代码实现

2.1 将TripletAttention添加到YOLOv5中

2.2 新增yaml文件

2.3 注册模块

2.4 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6.总结


1.原理

官方论文:Rotate to Attend: Convolutional Triplet Attention Module——点击即可跳转

官方代码:官方代码仓库地址——点击即可跳转

三重注意力机制(Triplet Attention)是一种深度学习中的注意力机制,旨在提高模型对输入数据的理解和表示能力。它在自然语言处理(NLP)和计算机视觉(CV)等领域都有应用。

这个机制的核心思想是将注意力机制引入到不同级别的特征表示中,以更全面地捕捉输入数据的信息。通常来说,传统的注意力机制会在同一级别的特征表示中计算注意力权重,而三重注意力机制则引入了三个不同级别的特征表示,并在每个级别上计算注意力权重,从而实现了“三重”的概念。

具体来说,三重注意力机制通常包含以下三个层次的注意力计算:

  1. 全局注意力(Global Attention): 全局注意力通常是在输入数据的最底层或最原始的表示上计算的,例如,在NLP中可能是词级别的表示,或者在CV中可能是原始图像的表示。在这一层次上,模型尝试理解整个输入的上下文信息,并计算每个部分的重要性。

  2. 组间注意力(Inter-group Attention): 组间注意力是在全局注意力得到的表示的基础上计算的。它将全局表示分成不同的组(可能是空间上的不同区域,或者是语义上的不同部分),然后在这些组之间计算注意力权重。这一层级的注意力有助于模型更好地理解输入数据中不同部分之间的关系和交互。

  3. 组内注意力(Intra-group Attention): 组内注意力是在组间注意力得到的表示的基础上计算的。它在每个组内部计算注意力权重,以捕捉组内部分的重要性和关联性。这一层级的注意力有助于模型更好地理解每个组内部分的内在结构和语义信息。

通过这三个层次的注意力计算,三重注意力机制可以在不同级别上捕捉输入数据的全局信息、组间关系和组内结构,从而更有效地理解和表示输入数据。

总的来说,三重注意力机制通过在不同级别上引入注意力机制,能够更全面地捕捉输入数据的信息,从而提高了深度学习模型的表现能力。

2. TripletAttention代码实现

2.1 将TripletAttention添加到YOLOv5中

关键步骤一将下面代码粘贴到/projects/yolov5-6.1/models/common.py文件中

import torch
import math
import torch.nn as nn
import torch.nn.functional as Fclass BasicConv(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):super(BasicConv, self).__init__()self.out_channels = out_planesself.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else Noneself.relu = nn.ReLU() if relu else Nonedef forward(self, x):x = self.conv(x)if self.bn is not None:x = self.bn(x)if self.relu is not None:x = self.relu(x)return xclass ChannelPool(nn.Module):def forward(self, x):return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )class SpatialGate(nn.Module):def __init__(self):super(SpatialGate, self).__init__()kernel_size = 7self.compress = ChannelPool()self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)def forward(self, x):x_compress = self.compress(x)x_out = self.spatial(x_compress)scale = torch.sigmoid_(x_out) return x * scaleclass TripletAttention(nn.Module):def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):super(TripletAttention, self).__init__()self.ChannelGateH = SpatialGate()self.ChannelGateW = SpatialGate()self.no_spatial=no_spatialif not no_spatial:self.SpatialGate = SpatialGate()def forward(self, x):x_perm1 = x.permute(0,2,1,3).contiguous()x_out1 = self.ChannelGateH(x_perm1)x_out11 = x_out1.permute(0,2,1,3).contiguous()x_perm2 = x.permute(0,3,2,1).contiguous()x_out2 = self.ChannelGateW(x_perm2)x_out21 = x_out2.permute(0,3,2,1).contiguous()if not self.no_spatial:x_out = self.SpatialGate(x)x_out = (1/3)*(x_out + x_out11 + x_out21)else:x_out = (1/2)*(x_out11 + x_out21)return x_out

三重注意力机制的主要流程可以分为以下步骤:

  1. 输入数据表示: 首先,将输入数据(例如文本序列、图像等)进行表示。这可能包括将文本序列转换为词嵌入向量、将图像转换为特征图等。这一步骤的目的是将输入数据转换为模型可以处理的表示形式。

  2. 全局注意力计算: 在第一级别,对输入数据的全局表示进行计算。这可以通过应用传统的注意力机制来实现,例如使用自注意力机制(Self-Attention)或注意力机制的变体。在这一步骤中,模型尝试理解整个输入的上下文信息,并计算每个部分的重要性。

  3. 组间表示生成: 在第二级别,根据全局注意力得到的权重,将全局表示分成不同的组。这些组可以根据具体的任务和数据特点来确定,例如在图像中可能是空间上的不同区域,在文本中可能是不同的句子或段落。然后,对每个组进行表示生成,得到组间表示。

  4. 组间注意力计算: 在第二级别,对组间表示进行注意力计算。这一步骤可以类似地使用注意力机制,但是针对的是组间的关系和交互。通过计算组间的注意力权重,模型可以更好地理解不同组之间的关系和重要性。

  5. 组内表示生成: 在第三级别,根据组间注意力得到的权重,将每个组内的表示进行生成。这一步骤可以帮助模型更好地理解每个组内部分的内在结构和语义信息。

  6. 组内注意力计算: 在第三级别,对每个组内的表示进行注意力计算。这类似于组间注意力计算,但是针对的是组内部分的关系和重要性。通过计算组内的注意力权重,模型可以更好地理解组内部分之间的关系和重要性。

  7. 输出: 最后,根据经过三级注意力机制处理后的表示,进行任务相关的后续处理,如分类、回归等,得到最终的输出结果。

总的来说,三重注意力机制通过在不同级别上引入注意力机制,实现了对输入数据的全局信息、组间关系和组内结构的捕捉和理解,从而提高了深度学习模型的表现能力。

2.2 新增yaml文件

关键步骤二在下/projects/yolov5-6.1/models下新建文件 yolov5_TripletAttention.yaml并将下面代码复制进去

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, TripletAttention, [128,3]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 7], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 5], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 15], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 11], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

温馨提示:本文只是对yolov5l基础上添加swin模块,如果要对yolov8n/l/m/x进行添加则只需要指定对应的depth_multiple 和 width_multiple。


# YOLOv5n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple# YOLOv5s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple# YOLOv5l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple# YOLOv5m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple# YOLOv5x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple

2.3 注册模块

关键步骤三在yolo.py中注册, 大概在260行左右添加 ‘TripletAttention’

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_TripletAttention.yaml的路径

建议大家写绝对路径,确保一定能找到

 🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/1RST9hL8La0GZ8n-kk9bXiw?pwd=45cq

提取码: 45cq 

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的GFLOPs

改进后的GFLOPs,可以看出这个计算量几乎没有变化,也印证了文章开头说的“计算量几乎b”

5. 进阶

你能在不同的位置添加三重注意力机制吗?这非常有趣,快去试试吧

6.总结

三重注意力机制(Triplet Attention)是一种深度学习中的注意力机制,通过在不同层次上引入注意力机制,增强模型对输入数据的理解和表示能力。其流程包括首先将输入数据转换为模型可处理的表示形式,然后在全局表示上计算注意力权重以捕捉整体上下文信息,接着根据全局注意力权重将全局表示分成不同的组并生成组间表示,再在组间表示上计算注意力权重以理解组间关系,随后生成组内表示并在组内计算注意力权重以捕捉组内结构和关系,最后基于处理后的表示进行任务相关的处理得到最终输出。通过全局、组间和组内三个层次的注意力计算,三重注意力机制能够更全面地捕捉输入数据的信息,从而提升模型的表现能力。

这篇关于YOLOv5改进 | 注意力机制 | 添加三重注意力机制 TripletAttention【原理 + 完整代码】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1020326

相关文章

Spring Boot 3.4.3 基于 Spring WebFlux 实现 SSE 功能(代码示例)

《SpringBoot3.4.3基于SpringWebFlux实现SSE功能(代码示例)》SpringBoot3.4.3结合SpringWebFlux实现SSE功能,为实时数据推送提供... 目录1. SSE 简介1.1 什么是 SSE?1.2 SSE 的优点1.3 适用场景2. Spring WebFlu

java之Objects.nonNull用法代码解读

《java之Objects.nonNull用法代码解读》:本文主要介绍java之Objects.nonNull用法代码,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录Java之Objects.nonwww.chinasem.cnNull用法代码Objects.nonN

Python中随机休眠技术原理与应用详解

《Python中随机休眠技术原理与应用详解》在编程中,让程序暂停执行特定时间是常见需求,当需要引入不确定性时,随机休眠就成为关键技巧,下面我们就来看看Python中随机休眠技术的具体实现与应用吧... 目录引言一、实现原理与基础方法1.1 核心函数解析1.2 基础实现模板1.3 整数版实现二、典型应用场景2

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

java中反射(Reflection)机制举例详解

《java中反射(Reflection)机制举例详解》Java中的反射机制是指Java程序在运行期间可以获取到一个对象的全部信息,:本文主要介绍java中反射(Reflection)机制的相关资料... 目录一、什么是反射?二、反射的用途三、获取Class对象四、Class类型的对象使用场景1五、Class

SpringBoot实现MD5加盐算法的示例代码

《SpringBoot实现MD5加盐算法的示例代码》加盐算法是一种用于增强密码安全性的技术,本文主要介绍了SpringBoot实现MD5加盐算法的示例代码,文中通过示例代码介绍的非常详细,对大家的学习... 目录一、什么是加盐算法二、如何实现加盐算法2.1 加盐算法代码实现2.2 注册页面中进行密码加盐2.

python+opencv处理颜色之将目标颜色转换实例代码

《python+opencv处理颜色之将目标颜色转换实例代码》OpenCV是一个的跨平台计算机视觉库,可以运行在Linux、Windows和MacOS操作系统上,:本文主要介绍python+ope... 目录下面是代码+ 效果 + 解释转HSV: 关于颜色总是要转HSV的掩膜再标注总结 目标:将红色的部分滤

在C#中调用Python代码的两种实现方式

《在C#中调用Python代码的两种实现方式》:本文主要介绍在C#中调用Python代码的两种实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录C#调用python代码的方式1. 使用 Python.NET2. 使用外部进程调用 Python 脚本总结C#调

Java时间轮调度算法的代码实现

《Java时间轮调度算法的代码实现》时间轮是一种高效的定时调度算法,主要用于管理延时任务或周期性任务,它通过一个环形数组(时间轮)和指针来实现,将大量定时任务分摊到固定的时间槽中,极大地降低了时间复杂... 目录1、简述2、时间轮的原理3. 时间轮的实现步骤3.1 定义时间槽3.2 定义时间轮3.3 使用时

Java中&和&&以及|和||的区别、应用场景和代码示例

《Java中&和&&以及|和||的区别、应用场景和代码示例》:本文主要介绍Java中的逻辑运算符&、&&、|和||的区别,包括它们在布尔和整数类型上的应用,文中通过代码介绍的非常详细,需要的朋友可... 目录前言1. & 和 &&代码示例2. | 和 ||代码示例3. 为什么要使用 & 和 | 而不是总是使