神经网络学习小记录73——Pytorch CA(Coordinate attention)注意力机制的解析与代码详解

本文主要是介绍神经网络学习小记录73——Pytorch CA(Coordinate attention)注意力机制的解析与代码详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

神经网络学习小记录73——Pytorch CA(Coordinate attention)注意力机制的解析与代码详解

  • 学习前言
  • 代码下载
  • CA注意力机制的概念与实现
  • 注意力机制的应用

学习前言

CA注意力机制是最近提出的一种注意力机制,全面关注特征层的空间信息和通道信息。
在这里插入图片描述

代码下载

Github源码下载地址为:
https://github.com/bubbliiiing/yolov4-tiny-pytorch

复制该路径到地址栏跳转。

CA注意力机制的概念与实现

在这里插入图片描述
该文章的作者认为现有的注意力机制(如CBAM、SE)在求取通道注意力的时候,通道的处理一般是采用全局最大池化/平均池化,这样会损失掉物体的空间信息。作者期望在引入通道注意力机制的同时,引入空间注意力机制,作者提出的注意力机制将位置信息嵌入到了通道注意力中。

CA注意力的实现如图所示,可以认为分为两个并行阶段:

将输入特征图分别在为宽度和高度两个方向分别进行全局平均池化,分别获得在宽度和高度两个方向的特征图。假设输入进来的特征层的形状为[C, H, W],在经过宽方向的平均池化后,获得的特征层shape为[C, H, 1],此时我们将特征映射到了高维度上;在经过高方向的平均池化后,获得的特征层shape为[C, 1, W],此时我们将特征映射到了宽维度上。

然后将两个并行阶段合并,将宽和高转置到同一个维度,然后进行堆叠,将宽高特征合并在一起,此时我们获得的特征层为:[C, 1, H+W],利用卷积+标准化+激活函数获得特征。

之后再次分开为两个并行阶段,再将宽高分开成为:[C, 1, H]和[C, 1, W],之后进行转置。获得两个特征层[C, H, 1]和[C, 1, W]。

然后利用1x1卷积调整通道数后取sigmoid获得宽高维度上的注意力情况。乘上原有的特征就是CA注意力机制。

实现的python代码为:

class CA_Block(nn.Module):def __init__(self, channel, reduction=16):super(CA_Block, self).__init__()self.conv_1x1 = nn.Conv2d(in_channels=channel, out_channels=channel//reduction, kernel_size=1, stride=1, bias=False)self.relu   = nn.ReLU()self.bn     = nn.BatchNorm2d(channel//reduction)self.F_h = nn.Conv2d(in_channels=channel//reduction, out_channels=channel, kernel_size=1, stride=1, bias=False)self.F_w = nn.Conv2d(in_channels=channel//reduction, out_channels=channel, kernel_size=1, stride=1, bias=False)self.sigmoid_h = nn.Sigmoid()self.sigmoid_w = nn.Sigmoid()def forward(self, x):_, _, h, w = x.size()x_h = torch.mean(x, dim = 3, keepdim = True).permute(0, 1, 3, 2)x_w = torch.mean(x, dim = 2, keepdim = True)x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h, x_w), 3))))x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h, w], 3)s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w))out = x * s_h.expand_as(x) * s_w.expand_as(x)return out

注意力机制的应用

注意力机制是一个即插即用的模块,理论上可以放在任何一个特征层后面,可以放在主干网络,也可以放在加强特征提取网络。

由于放置在主干会导致网络的预训练权重无法使用,本文以YoloV4-tiny为例,将注意力机制应用加强特征提取网络上。

如下图所示,我们在主干网络提取出来的两个有效特征层上增加了注意力机制,同时对上采样后的结果增加了注意力机制
在这里插入图片描述
实现代码如下:

attention_block = [se_block, cbam_block, eca_block, CA_Block]#---------------------------------------------------#
#   特征层->最后的输出
#---------------------------------------------------#
class YoloBody(nn.Module):def __init__(self, anchors_mask, num_classes, phi=0):super(YoloBody, self).__init__()self.phi            = phiself.backbone       = darknet53_tiny(None)self.conv_for_P5    = BasicConv(512,256,1)self.yolo_headP5    = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)],256)self.upsample       = Upsample(256,128)self.yolo_headP4    = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)],384)if 1 <= self.phi and self.phi <= 3:self.feat1_att      = attention_block[self.phi - 1](256)self.feat2_att      = attention_block[self.phi - 1](512)self.upsample_att   = attention_block[self.phi - 1](128)def forward(self, x):#---------------------------------------------------##   生成CSPdarknet53_tiny的主干模型#   feat1的shape为26,26,256#   feat2的shape为13,13,512#---------------------------------------------------#feat1, feat2 = self.backbone(x)if 1 <= self.phi and self.phi <= 3:feat1 = self.feat1_att(feat1)feat2 = self.feat2_att(feat2)# 13,13,512 -> 13,13,256P5 = self.conv_for_P5(feat2)# 13,13,256 -> 13,13,512 -> 13,13,255out0 = self.yolo_headP5(P5) # 13,13,256 -> 13,13,128 -> 26,26,128P5_Upsample = self.upsample(P5)# 26,26,256 + 26,26,128 -> 26,26,384if 1 <= self.phi and self.phi <= 3:P5_Upsample = self.upsample_att(P5_Upsample)P4 = torch.cat([P5_Upsample,feat1],axis=1)# 26,26,384 -> 26,26,256 -> 26,26,255out1 = self.yolo_headP4(P4)return out0, out1

这篇关于神经网络学习小记录73——Pytorch CA(Coordinate attention)注意力机制的解析与代码详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/214050

相关文章

详解Vue如何使用xlsx库导出Excel文件

《详解Vue如何使用xlsx库导出Excel文件》第三方库xlsx提供了强大的功能来处理Excel文件,它可以简化导出Excel文件这个过程,本文将为大家详细介绍一下它的具体使用,需要的小伙伴可以了解... 目录1. 安装依赖2. 创建vue组件3. 解释代码在Vue.js项目中导出Excel文件,使用第三

Linux中shell解析脚本的通配符、元字符、转义符说明

《Linux中shell解析脚本的通配符、元字符、转义符说明》:本文主要介绍shell通配符、元字符、转义符以及shell解析脚本的过程,通配符用于路径扩展,元字符用于多命令分割,转义符用于将特殊... 目录一、linux shell通配符(wildcard)二、shell元字符(特殊字符 Meta)三、s

SQL注入漏洞扫描之sqlmap详解

《SQL注入漏洞扫描之sqlmap详解》SQLMap是一款自动执行SQL注入的审计工具,支持多种SQL注入技术,包括布尔型盲注、时间型盲注、报错型注入、联合查询注入和堆叠查询注入... 目录what支持类型how---less-1为例1.检测网站是否存在sql注入漏洞的注入点2.列举可用数据库3.列举数据库

Linux之软件包管理器yum详解

《Linux之软件包管理器yum详解》文章介绍了现代类Unix操作系统中软件包管理和包存储库的工作原理,以及如何使用包管理器如yum来安装、更新和卸载软件,文章还介绍了如何配置yum源,更新系统软件包... 目录软件包yumyum语法yum常用命令yum源配置文件介绍更新yum源查看已经安装软件的方法总结软

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

java图像识别工具类(ImageRecognitionUtils)使用实例详解

《java图像识别工具类(ImageRecognitionUtils)使用实例详解》:本文主要介绍如何在Java中使用OpenCV进行图像识别,包括图像加载、预处理、分类、人脸检测和特征提取等步骤... 目录前言1. 图像识别的背景与作用2. 设计目标3. 项目依赖4. 设计与实现 ImageRecogni

Java访问修饰符public、private、protected及默认访问权限详解

《Java访问修饰符public、private、protected及默认访问权限详解》:本文主要介绍Java访问修饰符public、private、protected及默认访问权限的相关资料,每... 目录前言1. public 访问修饰符特点:示例:适用场景:2. private 访问修饰符特点:示例:

python管理工具之conda安装部署及使用详解

《python管理工具之conda安装部署及使用详解》这篇文章详细介绍了如何安装和使用conda来管理Python环境,它涵盖了从安装部署、镜像源配置到具体的conda使用方法,包括创建、激活、安装包... 目录pytpshheraerUhon管理工具:conda部署+使用一、安装部署1、 下载2、 安装3

详解Java如何向http/https接口发出请求

《详解Java如何向http/https接口发出请求》这篇文章主要为大家详细介绍了Java如何实现向http/https接口发出请求,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 用Java发送web请求所用到的包都在java.net下,在具体使用时可以用如下代码,你可以把它封装成一

JAVA系统中Spring Boot应用程序的配置文件application.yml使用详解

《JAVA系统中SpringBoot应用程序的配置文件application.yml使用详解》:本文主要介绍JAVA系统中SpringBoot应用程序的配置文件application.yml的... 目录文件路径文件内容解释1. Server 配置2. Spring 配置3. Logging 配置4. Ma