attention机制SENET、CBAM模块原理总结

2024-01-09 10:32

本文主要是介绍attention机制SENET、CBAM模块原理总结,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考博客:https://blog.csdn.net/weixin_33602281/article/details/85223216

def cbam_module(inputs,reduction_ratio=0.5,name=""):with tf.variable_scope("cbam_"+name, reuse=tf.AUTO_REUSE):#假如输入是[batsize,h,w,channel],#channel attension 因为要得到batsize * 1 * 1 * channel,它的全连接层第一层#隐藏层单元个数是channel / r, 第二层是channel,所以这里把channel赋值给hidden_numbatch_size,hidden_num=inputs.get_shape().as_list()[0],inputs.get_shape().as_list()[3]#通道attension#全局最大池化,窗口大小为h * w,所以对于这个数据[batsize,h,w,channel],他其实是求每个h * w面积的最大值#这里实现是先对h这个维度求最大值,然后对w这个维度求最大值,平均池化也一样maxpool_channel=tf.reduce_max(tf.reduce_max(inputs,axis=1,keepdims=True),axis=2,keepdims=True)avgpool_channel=tf.reduce_mean(tf.reduce_mean(inputs,axis=1,keepdims=True),axis=2,keepdims=True)#上面全局池化结果为batsize * 1 * 1 * channel,它这个拉平输入到全连接层#这个拉平,它会保留batsize,所以结果是[batsize,channel]maxpool_channel = tf.layers.Flatten()(maxpool_channel)avgpool_channel = tf.layers.Flatten()(avgpool_channel)#将上面拉平后结果输入到全连接层,第一个全连接层hiddensize = channel/r = channel * reduction_ratio,#第二哥全连接层hiddensize = channelmlp_1_max=tf.layers.dense(inputs=maxpool_channel,units=int(hidden_num*reduction_ratio),name="mlp_1",reuse=None,activation=tf.nn.relu)mlp_2_max=tf.layers.dense(inputs=mlp_1_max,units=hidden_num,name="mlp_2",reuse=None)#全连接层输出结果为[batsize,channel],这里又降它转回到原来维度batsize * 1 * 1 * channel,mlp_2_max=tf.reshape(mlp_2_max,[batch_size,1,1,hidden_num])mlp_1_avg=tf.layers.dense(inputs=avgpool_channel,units=int(hidden_num*reduction_ratio),name="mlp_1",reuse=True,activation=tf.nn.relu)mlp_2_avg=tf.layers.dense(inputs=mlp_1_avg,units=hidden_num,name="mlp_2",reuse=True)mlp_2_avg=tf.reshape(mlp_2_avg,[batch_size,1,1,hidden_num])#将平均和最大池化的结果维度都是[batch_size,1,1,channel]相加,然后进行sigmod,维度不变channel_attention=tf.nn.sigmoid(mlp_2_max+mlp_2_avg)#和最开始的inputs相乘,相当于[batch_size,1,1,channel] * [batch_size,h,w,channel]#只有维度一样才能相乘,这里相乘相当于给每个通道作用了不同的权重channel_refined_feature=inputs*channel_attention#空间attension#上面得到的结果维度依然是[batch_size,h,w,channel],#下面要进行全局通道池化,其实就是一条通道里面那个通道的值最大,其实就是对channel这个维度求最大值#每个通道池化相当于将通道压缩到了1维,有两个池化,结果为两个[batch_size,h,w,1]feature mapmaxpool_spatial=tf.reduce_max(inputs,axis=3,keepdims=True)avgpool_spatial=tf.reduce_mean(inputs,axis=3,keepdims=True)#将两个[batch_size,h,w,1]的feature map进行通道合并得到[batch_size,h,w,2]的feature mapmax_avg_pool_spatial=tf.concat([maxpool_spatial,avgpool_spatial],axis=3)#然后对上面的feature map用1个7*7的卷积核进行卷积得到[batch_size,h,w,1]的feature map,因为是用一个卷积核卷的#所以将2个输入通道压缩到了1个输出通道conv_layer=tf.layers.conv2d(inputs=max_avg_pool_spatial, filters=1, kernel_size=(7, 7), padding="same", activation=None)#然后再对上面得到的[batch_size,h,w,1]feature map进行sigmod,这里为什么要用一个卷积核压缩到1个通道,相当于只得到了一个面积的值#然后进行sigmod,因为我们要求的就是feature map面积上不同位置像素的中重要性,所以它压缩到了一个通道,然后求sigmodspatial_attention=tf.nn.sigmoid(conv_layer)#上面得到了空间attension feature map [batch_size,h,w,1],然后再用这个和经过空间attension作用的结果相乘得到最终的结果#这个结果就是经过通道和空间attension共同作用的结果refined_feature=channel_refined_feature*spatial_attentionreturn refined_feature

这篇关于attention机制SENET、CBAM模块原理总结的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/586816

相关文章

Spring排序机制之接口与注解的使用方法

《Spring排序机制之接口与注解的使用方法》本文介绍了Spring中多种排序机制,包括Ordered接口、PriorityOrdered接口、@Order注解和@Priority注解,提供了详细示例... 目录一、Spring 排序的需求场景二、Spring 中的排序机制1、Ordered 接口2、Pri

Python利用自带模块实现屏幕像素高效操作

《Python利用自带模块实现屏幕像素高效操作》这篇文章主要为大家详细介绍了Python如何利用自带模块实现屏幕像素高效操作,文中的示例代码讲解详,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1、获取屏幕放缩比例2、获取屏幕指定坐标处像素颜色3、一个简单的使用案例4、总结1、获取屏幕放缩比例from

MySQL 缓存机制与架构解析(最新推荐)

《MySQL缓存机制与架构解析(最新推荐)》本文详细介绍了MySQL的缓存机制和整体架构,包括一级缓存(InnoDBBufferPool)和二级缓存(QueryCache),文章还探讨了SQL... 目录一、mysql缓存机制概述二、MySQL整体架构三、SQL查询执行全流程四、MySQL 8.0为何移除查

nginx-rtmp-module模块实现视频点播的示例代码

《nginx-rtmp-module模块实现视频点播的示例代码》本文主要介绍了nginx-rtmp-module模块实现视频点播,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习... 目录预置条件Nginx点播基本配置点播远程文件指定多个播放位置参考预置条件配置点播服务器 192.

MySQL中的MVCC底层原理解读

《MySQL中的MVCC底层原理解读》本文详细介绍了MySQL中的多版本并发控制(MVCC)机制,包括版本链、ReadView以及在不同事务隔离级别下MVCC的工作原理,通过一个具体的示例演示了在可重... 目录简介ReadView版本链演示过程总结简介MVCC(Multi-Version Concurr

Python中连接不同数据库的方法总结

《Python中连接不同数据库的方法总结》在数据驱动的现代应用开发中,Python凭借其丰富的库和强大的生态系统,成为连接各种数据库的理想编程语言,下面我们就来看看如何使用Python实现连接常用的几... 目录一、连接mysql数据库二、连接PostgreSQL数据库三、连接SQLite数据库四、连接Mo

一文详解Java Condition的await和signal等待通知机制

《一文详解JavaCondition的await和signal等待通知机制》这篇文章主要为大家详细介绍了JavaCondition的await和signal等待通知机制的相关知识,文中的示例代码讲... 目录1. Condition的核心方法2. 使用场景与优势3. 使用流程与规范基本模板生产者-消费者示例

Git提交代码详细流程及问题总结

《Git提交代码详细流程及问题总结》:本文主要介绍Git的三大分区,分别是工作区、暂存区和版本库,并详细描述了提交、推送、拉取代码和合并分支的流程,文中通过代码介绍的非常详解,需要的朋友可以参考下... 目录1.git 三大分区2.Git提交、推送、拉取代码、合并分支详细流程3.问题总结4.git push

Kubernetes常用命令大全近期总结

《Kubernetes常用命令大全近期总结》Kubernetes是用于大规模部署和管理这些容器的开源软件-在希腊语中,这个词还有“舵手”或“飞行员”的意思,使用Kubernetes(有时被称为“... 目录前言Kubernetes 的工作原理为什么要使用 Kubernetes?Kubernetes常用命令总

一文带你理解Python中import机制与importlib的妙用

《一文带你理解Python中import机制与importlib的妙用》在Python编程的世界里,import语句是开发者最常用的工具之一,它就像一把钥匙,打开了通往各种功能和库的大门,下面就跟随小... 目录一、python import机制概述1.1 import语句的基本用法1.2 模块缓存机制1.