CBAM注意力机制详解(附pytorch复现)

2024-03-01 07:12

本文主要是介绍CBAM注意力机制详解(附pytorch复现),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简介

论文原址:1807.06521.pdf (arxiv.org)

CBAM(Convolutional Block Attention Module)是一种卷积神经网络模块,旨在通过引入注意力机制来提升网络的表示能力。CBAM包含两个顺序子模块:通道注意力模块和空间注意力模块。

通过在深度网络的每个卷积块中自适应地优化中间特征图,CBAM通过强调通道和空间维度上的有意义特征,实现了对关键信息的关注和不必要信息的抑制。研究表明,CBAM在ImageNet-1K数据集上能够显著提高各种基线网络的准确性,通过grad-CAM可视化验证,CBAM增强的网络能够更准确地关注目标对象。在MS COCO和VOC 2007数据集上的目标检测任务中,CBAM也展现出显著的性能改进,而由于CBAM精心设计为轻量级模块,其在大多数情况下几乎没有参数和计算开销。CBAM注意力模块可广泛应用于提升卷积神经网络的表示能力。

Channel attention module(CAM)

通过平均池化和最大池化操作,整合输入特征图的空间信息,生成两个不同的空间上下文描述符,得到两个 1×1×C 的特征图,分别表示为 F_c_avg 和 F_c_max。将 F_c_avg 和 F_c_max 分别送入一个共享的多层感知机(MLP),该 MLP 具有一个隐藏层,其中第一层神经元个数为 C/r(r 为减少率),激活函数为 ReLU,第二层神经元个数为 C。这两层神经网络是共享的,即它们的权重相同。将两个 MLP 的输出特征进行逐元素相加,并通过 sigmoid 激活函数,生成通道注意力图 Mc。

这是对池化操作的使用进行实验比较的结果。研究者发现,采用平均池化和最大池化并行的方式能够取得更好的效果。可能是因为采用并行连接方式,相比于单一的池化,能够更有效地保留有用的信息,进而提升模型性能。

Spatial attention module(SAM)

首先,将 Channel Attention 模块输出的特征图作为 Spatial Attention 模块的输入特征图。接着,对输入特征图进行基于通道的全局最大池化和全局平均池化操作,得到两个 H×W×1 的特征图。然后,将这两个特征图在通道维度上进行拼接,经过一个 7×7 的卷积操作,将通道数降维为 1,即得到 H×W×1 的特征图。最后,经过 sigmoid 操作生成空间注意力特征,即 Ms。将该特征与输入特征图进行乘法操作,得到最终生成的特征。这一过程有助于模型关注输入特征图中的重要区域,从而增强表示能力。

CBAM的pytorch实现

"""
Original paper addresshttps: https://arxiv.org/pdf/1807.06521.pdf
Time: 2024-02-28
"""
import torch
from torch import nnclass ChannelAttention(nn.Module):def __init__(self, in_planes, reduction=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)# shared MLPself.mlp = nn.Sequential(nn.Conv2d(in_planes, in_planes // reduction, 1, bias=False),nn.ReLU(inplace=True),nn.Conv2d(in_planes // reduction, in_planes, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.mlp(self.avg_pool(x))max_out = self.mlp(self.max_pool(x))out = avg_out + max_outreturn self.sigmoid(out)class SpatialAttention(nn.Module):def __init__(self, kernel_size=7, padding=3):super(SpatialAttention, self).__init__()self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)return self.sigmoid(x)class CBAM(nn.Module):def __init__(self, in_planes, reduction=16, kernel_size=7):super(CBAM, self).__init__()self.ca = ChannelAttention(in_planes, reduction)self.sa = SpatialAttention(kernel_size)def forward(self, x):out = x * self.ca(x)result = out * self.sa(out)return resultif __name__ == '__main__':block = CBAM(16)input = torch.rand(1, 16, 8, 8)output = block(input)print(output.shape)

参考文章

CBAM——即插即用的注意力模块(附代码)_cbam模块-CSDN博客

[ 注意力机制 ] 经典网络模型2——CBAM 详解与复现_cbam代码复现-CSDN博客

这篇关于CBAM注意力机制详解(附pytorch复现)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/761602

相关文章

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof

6.1.数据结构-c/c++堆详解下篇(堆排序,TopK问题)

上篇:6.1.数据结构-c/c++模拟实现堆上篇(向下,上调整算法,建堆,增删数据)-CSDN博客 本章重点 1.使用堆来完成堆排序 2.使用堆解决TopK问题 目录 一.堆排序 1.1 思路 1.2 代码 1.3 简单测试 二.TopK问题 2.1 思路(求最小): 2.2 C语言代码(手写堆) 2.3 C++代码(使用优先级队列 priority_queue)

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

K8S(Kubernetes)开源的容器编排平台安装步骤详解

K8S(Kubernetes)是一个开源的容器编排平台,用于自动化部署、扩展和管理容器化应用程序。以下是K8S容器编排平台的安装步骤、使用方式及特点的概述: 安装步骤: 安装Docker:K8S需要基于Docker来运行容器化应用程序。首先要在所有节点上安装Docker引擎。 安装Kubernetes Master:在集群中选择一台主机作为Master节点,安装K8S的控制平面组件,如AP

【编程底层思考】垃圾收集机制,GC算法,垃圾收集器类型概述

Java的垃圾收集(Garbage Collection,GC)机制是Java语言的一大特色,它负责自动管理内存的回收,释放不再使用的对象所占用的内存。以下是对Java垃圾收集机制的详细介绍: 一、垃圾收集机制概述: 对象存活判断:垃圾收集器定期检查堆内存中的对象,判断哪些对象是“垃圾”,即不再被任何引用链直接或间接引用的对象。内存回收:将判断为垃圾的对象占用的内存进行回收,以便重新使用。

【Tools】大模型中的自注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 自注意力机制(Self-Attention)是一种在Transformer等大模型中经常使用的注意力机制。该机制通过对输入序列中的每个元素计算与其他元素之间的相似性,

如何通俗理解注意力机制?

1、注意力机制(Attention Mechanism)是机器学习和深度学习中一种模拟人类注意力的方法,用于提高模型在处理大量信息时的效率和效果。通俗地理解,它就像是在一堆信息中找到最重要的部分,把注意力集中在这些关键点上,从而更好地完成任务。以下是几个简单的比喻来帮助理解注意力机制: 2、寻找重点:想象一下,你在阅读一篇文章的时候,有些段落特别重要,你会特别注意这些段落,反复阅读,而对其他部分

嵌入式Openharmony系统构建与启动详解

大家好,今天主要给大家分享一下,如何构建Openharmony子系统以及系统的启动过程分解。 第一:OpenHarmony系统构建      首先熟悉一下,构建系统是一种自动化处理工具的集合,通过将源代码文件进行一系列处理,最终生成和用户可以使用的目标文件。这里的目标文件包括静态链接库文件、动态链接库文件、可执行文件、脚本文件、配置文件等。      我们在编写hellowor