YOLOv5改进 | 注意力机制 | 添加三重注意力机制 TripletAttention【原理 + 完整代码】

本文主要是介绍YOLOv5改进 | 注意力机制 | 添加三重注意力机制 TripletAttention【原理 + 完整代码】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡

得益于在通道或空间位置之间建立相互依赖关系的能力,近年来,注意力机制在计算机视觉任务中得到了广泛的研究和应用。一种轻量级但有效的注意力机制——三重注意力,这是一种通过使用三分支结构捕获跨维度交互来计算注意力权重的创新方法。对于一个输入张量,三重注意力通过旋转变换建立跨维度依赖关系,并通过残差变换编码跨通道和空间信息,几乎不增加计算开销。在本文中,给大家带来的教程是将原来的网络添加TripletAttention。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址 YOLOv5改进+入门——持续更新各种有效涨点方法 点击即可跳转

目录

1.原理

2. TripletAttention代码实现

2.1 将TripletAttention添加到YOLOv5中

2.2 新增yaml文件

2.3 注册模块

2.4 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6.总结


1.原理

官方论文:Rotate to Attend: Convolutional Triplet Attention Module——点击即可跳转

官方代码:官方代码仓库地址——点击即可跳转

三重注意力机制(Triplet Attention)是一种深度学习中的注意力机制,旨在提高模型对输入数据的理解和表示能力。它在自然语言处理(NLP)和计算机视觉(CV)等领域都有应用。

这个机制的核心思想是将注意力机制引入到不同级别的特征表示中,以更全面地捕捉输入数据的信息。通常来说,传统的注意力机制会在同一级别的特征表示中计算注意力权重,而三重注意力机制则引入了三个不同级别的特征表示,并在每个级别上计算注意力权重,从而实现了“三重”的概念。

具体来说,三重注意力机制通常包含以下三个层次的注意力计算:

  1. 全局注意力(Global Attention): 全局注意力通常是在输入数据的最底层或最原始的表示上计算的,例如,在NLP中可能是词级别的表示,或者在CV中可能是原始图像的表示。在这一层次上,模型尝试理解整个输入的上下文信息,并计算每个部分的重要性。

  2. 组间注意力(Inter-group Attention): 组间注意力是在全局注意力得到的表示的基础上计算的。它将全局表示分成不同的组(可能是空间上的不同区域,或者是语义上的不同部分),然后在这些组之间计算注意力权重。这一层级的注意力有助于模型更好地理解输入数据中不同部分之间的关系和交互。

  3. 组内注意力(Intra-group Attention): 组内注意力是在组间注意力得到的表示的基础上计算的。它在每个组内部计算注意力权重,以捕捉组内部分的重要性和关联性。这一层级的注意力有助于模型更好地理解每个组内部分的内在结构和语义信息。

通过这三个层次的注意力计算,三重注意力机制可以在不同级别上捕捉输入数据的全局信息、组间关系和组内结构,从而更有效地理解和表示输入数据。

总的来说,三重注意力机制通过在不同级别上引入注意力机制,能够更全面地捕捉输入数据的信息,从而提高了深度学习模型的表现能力。

2. TripletAttention代码实现

2.1 将TripletAttention添加到YOLOv5中

关键步骤一将下面代码粘贴到/projects/yolov5-6.1/models/common.py文件中

import torch
import math
import torch.nn as nn
import torch.nn.functional as Fclass BasicConv(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):super(BasicConv, self).__init__()self.out_channels = out_planesself.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else Noneself.relu = nn.ReLU() if relu else Nonedef forward(self, x):x = self.conv(x)if self.bn is not None:x = self.bn(x)if self.relu is not None:x = self.relu(x)return xclass ChannelPool(nn.Module):def forward(self, x):return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )class SpatialGate(nn.Module):def __init__(self):super(SpatialGate, self).__init__()kernel_size = 7self.compress = ChannelPool()self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)def forward(self, x):x_compress = self.compress(x)x_out = self.spatial(x_compress)scale = torch.sigmoid_(x_out) return x * scaleclass TripletAttention(nn.Module):def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):super(TripletAttention, self).__init__()self.ChannelGateH = SpatialGate()self.ChannelGateW = SpatialGate()self.no_spatial=no_spatialif not no_spatial:self.SpatialGate = SpatialGate()def forward(self, x):x_perm1 = x.permute(0,2,1,3).contiguous()x_out1 = self.ChannelGateH(x_perm1)x_out11 = x_out1.permute(0,2,1,3).contiguous()x_perm2 = x.permute(0,3,2,1).contiguous()x_out2 = self.ChannelGateW(x_perm2)x_out21 = x_out2.permute(0,3,2,1).contiguous()if not self.no_spatial:x_out = self.SpatialGate(x)x_out = (1/3)*(x_out + x_out11 + x_out21)else:x_out = (1/2)*(x_out11 + x_out21)return x_out

三重注意力机制的主要流程可以分为以下步骤:

  1. 输入数据表示: 首先,将输入数据(例如文本序列、图像等)进行表示。这可能包括将文本序列转换为词嵌入向量、将图像转换为特征图等。这一步骤的目的是将输入数据转换为模型可以处理的表示形式。

  2. 全局注意力计算: 在第一级别,对输入数据的全局表示进行计算。这可以通过应用传统的注意力机制来实现,例如使用自注意力机制(Self-Attention)或注意力机制的变体。在这一步骤中,模型尝试理解整个输入的上下文信息,并计算每个部分的重要性。

  3. 组间表示生成: 在第二级别,根据全局注意力得到的权重,将全局表示分成不同的组。这些组可以根据具体的任务和数据特点来确定,例如在图像中可能是空间上的不同区域,在文本中可能是不同的句子或段落。然后,对每个组进行表示生成,得到组间表示。

  4. 组间注意力计算: 在第二级别,对组间表示进行注意力计算。这一步骤可以类似地使用注意力机制,但是针对的是组间的关系和交互。通过计算组间的注意力权重,模型可以更好地理解不同组之间的关系和重要性。

  5. 组内表示生成: 在第三级别,根据组间注意力得到的权重,将每个组内的表示进行生成。这一步骤可以帮助模型更好地理解每个组内部分的内在结构和语义信息。

  6. 组内注意力计算: 在第三级别,对每个组内的表示进行注意力计算。这类似于组间注意力计算,但是针对的是组内部分的关系和重要性。通过计算组内的注意力权重,模型可以更好地理解组内部分之间的关系和重要性。

  7. 输出: 最后,根据经过三级注意力机制处理后的表示,进行任务相关的后续处理,如分类、回归等,得到最终的输出结果。

总的来说,三重注意力机制通过在不同级别上引入注意力机制,实现了对输入数据的全局信息、组间关系和组内结构的捕捉和理解,从而提高了深度学习模型的表现能力。

2.2 新增yaml文件

关键步骤二在下/projects/yolov5-6.1/models下新建文件 yolov5_TripletAttention.yaml并将下面代码复制进去

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, TripletAttention, [128,3]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 7], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 5], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 15], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 11], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

温馨提示:本文只是对yolov5l基础上添加swin模块,如果要对yolov8n/l/m/x进行添加则只需要指定对应的depth_multiple 和 width_multiple。


# YOLOv5n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple# YOLOv5s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple# YOLOv5l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple# YOLOv5m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple# YOLOv5x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple

2.3 注册模块

关键步骤三在yolo.py中注册, 大概在260行左右添加 ‘TripletAttention’

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_TripletAttention.yaml的路径

建议大家写绝对路径,确保一定能找到

 🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/1RST9hL8La0GZ8n-kk9bXiw?pwd=45cq

提取码: 45cq 

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的GFLOPs

改进后的GFLOPs,可以看出这个计算量几乎没有变化,也印证了文章开头说的“计算量几乎b”

5. 进阶

你能在不同的位置添加三重注意力机制吗?这非常有趣,快去试试吧

6.总结

三重注意力机制(Triplet Attention)是一种深度学习中的注意力机制,通过在不同层次上引入注意力机制,增强模型对输入数据的理解和表示能力。其流程包括首先将输入数据转换为模型可处理的表示形式,然后在全局表示上计算注意力权重以捕捉整体上下文信息,接着根据全局注意力权重将全局表示分成不同的组并生成组间表示,再在组间表示上计算注意力权重以理解组间关系,随后生成组内表示并在组内计算注意力权重以捕捉组内结构和关系,最后基于处理后的表示进行任务相关的处理得到最终输出。通过全局、组间和组内三个层次的注意力计算,三重注意力机制能够更全面地捕捉输入数据的信息,从而提升模型的表现能力。

这篇关于YOLOv5改进 | 注意力机制 | 添加三重注意力机制 TripletAttention【原理 + 完整代码】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1020326

相关文章

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

Python中顺序结构和循环结构示例代码

《Python中顺序结构和循环结构示例代码》:本文主要介绍Python中的条件语句和循环语句,条件语句用于根据条件执行不同的代码块,循环语句用于重复执行一段代码,文章还详细说明了range函数的使... 目录一、条件语句(1)条件语句的定义(2)条件语句的语法(a)单分支 if(b)双分支 if-else(

PyCharm 接入 DeepSeek最新完整教程

《PyCharm接入DeepSeek最新完整教程》文章介绍了DeepSeek-V3模型的性能提升以及如何在PyCharm中接入和使用DeepSeek进行代码开发,本文通过图文并茂的形式给大家介绍的... 目录DeepSeek-V3效果演示创建API Key在PyCharm中下载Continue插件配置Con

Spring排序机制之接口与注解的使用方法

《Spring排序机制之接口与注解的使用方法》本文介绍了Spring中多种排序机制,包括Ordered接口、PriorityOrdered接口、@Order注解和@Priority注解,提供了详细示例... 目录一、Spring 排序的需求场景二、Spring 中的排序机制1、Ordered 接口2、Pri

MySQL数据库函数之JSON_EXTRACT示例代码

《MySQL数据库函数之JSON_EXTRACT示例代码》:本文主要介绍MySQL数据库函数之JSON_EXTRACT的相关资料,JSON_EXTRACT()函数用于从JSON文档中提取值,支持对... 目录前言基本语法路径表达式示例示例 1: 提取简单值示例 2: 提取嵌套值示例 3: 提取数组中的值注意

CSS3中使用flex和grid实现等高元素布局的示例代码

《CSS3中使用flex和grid实现等高元素布局的示例代码》:本文主要介绍了使用CSS3中的Flexbox和Grid布局实现等高元素布局的方法,通过简单的两列实现、每行放置3列以及全部代码的展示,展示了这两种布局方式的实现细节和效果,详细内容请阅读本文,希望能对你有所帮助... 过往的实现方法是使用浮动加

JAVA调用Deepseek的api完成基本对话简单代码示例

《JAVA调用Deepseek的api完成基本对话简单代码示例》:本文主要介绍JAVA调用Deepseek的api完成基本对话的相关资料,文中详细讲解了如何获取DeepSeekAPI密钥、添加H... 获取API密钥首先,从DeepSeek平台获取API密钥,用于身份验证。添加HTTP客户端依赖使用Jav

Java实现状态模式的示例代码

《Java实现状态模式的示例代码》状态模式是一种行为型设计模式,允许对象根据其内部状态改变行为,本文主要介绍了Java实现状态模式的示例代码,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来... 目录一、简介1、定义2、状态模式的结构二、Java实现案例1、电灯开关状态案例2、番茄工作法状态案例

本地搭建DeepSeek-R1、WebUI的完整过程及访问

《本地搭建DeepSeek-R1、WebUI的完整过程及访问》:本文主要介绍本地搭建DeepSeek-R1、WebUI的完整过程及访问的相关资料,DeepSeek-R1是一个开源的人工智能平台,主... 目录背景       搭建准备基础概念搭建过程访问对话测试总结背景       最近几年,人工智能技术