attention机制SENET、CBAM模块原理总结

2024-01-09 10:32

本文主要是介绍attention机制SENET、CBAM模块原理总结,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考博客:https://blog.csdn.net/weixin_33602281/article/details/85223216

def cbam_module(inputs,reduction_ratio=0.5,name=""):with tf.variable_scope("cbam_"+name, reuse=tf.AUTO_REUSE):#假如输入是[batsize,h,w,channel],#channel attension 因为要得到batsize * 1 * 1 * channel,它的全连接层第一层#隐藏层单元个数是channel / r, 第二层是channel,所以这里把channel赋值给hidden_numbatch_size,hidden_num=inputs.get_shape().as_list()[0],inputs.get_shape().as_list()[3]#通道attension#全局最大池化,窗口大小为h * w,所以对于这个数据[batsize,h,w,channel],他其实是求每个h * w面积的最大值#这里实现是先对h这个维度求最大值,然后对w这个维度求最大值,平均池化也一样maxpool_channel=tf.reduce_max(tf.reduce_max(inputs,axis=1,keepdims=True),axis=2,keepdims=True)avgpool_channel=tf.reduce_mean(tf.reduce_mean(inputs,axis=1,keepdims=True),axis=2,keepdims=True)#上面全局池化结果为batsize * 1 * 1 * channel,它这个拉平输入到全连接层#这个拉平,它会保留batsize,所以结果是[batsize,channel]maxpool_channel = tf.layers.Flatten()(maxpool_channel)avgpool_channel = tf.layers.Flatten()(avgpool_channel)#将上面拉平后结果输入到全连接层,第一个全连接层hiddensize = channel/r = channel * reduction_ratio,#第二哥全连接层hiddensize = channelmlp_1_max=tf.layers.dense(inputs=maxpool_channel,units=int(hidden_num*reduction_ratio),name="mlp_1",reuse=None,activation=tf.nn.relu)mlp_2_max=tf.layers.dense(inputs=mlp_1_max,units=hidden_num,name="mlp_2",reuse=None)#全连接层输出结果为[batsize,channel],这里又降它转回到原来维度batsize * 1 * 1 * channel,mlp_2_max=tf.reshape(mlp_2_max,[batch_size,1,1,hidden_num])mlp_1_avg=tf.layers.dense(inputs=avgpool_channel,units=int(hidden_num*reduction_ratio),name="mlp_1",reuse=True,activation=tf.nn.relu)mlp_2_avg=tf.layers.dense(inputs=mlp_1_avg,units=hidden_num,name="mlp_2",reuse=True)mlp_2_avg=tf.reshape(mlp_2_avg,[batch_size,1,1,hidden_num])#将平均和最大池化的结果维度都是[batch_size,1,1,channel]相加,然后进行sigmod,维度不变channel_attention=tf.nn.sigmoid(mlp_2_max+mlp_2_avg)#和最开始的inputs相乘,相当于[batch_size,1,1,channel] * [batch_size,h,w,channel]#只有维度一样才能相乘,这里相乘相当于给每个通道作用了不同的权重channel_refined_feature=inputs*channel_attention#空间attension#上面得到的结果维度依然是[batch_size,h,w,channel],#下面要进行全局通道池化,其实就是一条通道里面那个通道的值最大,其实就是对channel这个维度求最大值#每个通道池化相当于将通道压缩到了1维,有两个池化,结果为两个[batch_size,h,w,1]feature mapmaxpool_spatial=tf.reduce_max(inputs,axis=3,keepdims=True)avgpool_spatial=tf.reduce_mean(inputs,axis=3,keepdims=True)#将两个[batch_size,h,w,1]的feature map进行通道合并得到[batch_size,h,w,2]的feature mapmax_avg_pool_spatial=tf.concat([maxpool_spatial,avgpool_spatial],axis=3)#然后对上面的feature map用1个7*7的卷积核进行卷积得到[batch_size,h,w,1]的feature map,因为是用一个卷积核卷的#所以将2个输入通道压缩到了1个输出通道conv_layer=tf.layers.conv2d(inputs=max_avg_pool_spatial, filters=1, kernel_size=(7, 7), padding="same", activation=None)#然后再对上面得到的[batch_size,h,w,1]feature map进行sigmod,这里为什么要用一个卷积核压缩到1个通道,相当于只得到了一个面积的值#然后进行sigmod,因为我们要求的就是feature map面积上不同位置像素的中重要性,所以它压缩到了一个通道,然后求sigmodspatial_attention=tf.nn.sigmoid(conv_layer)#上面得到了空间attension feature map [batch_size,h,w,1],然后再用这个和经过空间attension作用的结果相乘得到最终的结果#这个结果就是经过通道和空间attension共同作用的结果refined_feature=channel_refined_feature*spatial_attentionreturn refined_feature

这篇关于attention机制SENET、CBAM模块原理总结的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/586816

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

hdu4407(容斥原理)

题意:给一串数字1,2,......n,两个操作:1、修改第k个数字,2、查询区间[l,r]中与n互质的数之和。 解题思路:咱一看,像线段树,但是如果用线段树做,那么每个区间一定要记录所有的素因子,这样会超内存。然后我就做不来了。后来看了题解,原来是用容斥原理来做的。还记得这道题目吗?求区间[1,r]中与p互质的数的个数,如果不会的话就先去做那题吧。现在这题是求区间[l,r]中与n互质的数的和

git使用的说明总结

Git使用说明 下载安装(下载地址) macOS: Git - Downloading macOS Windows: Git - Downloading Windows Linux/Unix: Git (git-scm.com) 创建新仓库 本地创建新仓库:创建新文件夹,进入文件夹目录,执行指令 git init ,用以创建新的git 克隆仓库 执行指令用以创建一个本地仓库的

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

二分最大匹配总结

HDU 2444  黑白染色 ,二分图判定 const int maxn = 208 ;vector<int> g[maxn] ;int n ;bool vis[maxn] ;int match[maxn] ;;int color[maxn] ;int setcolor(int u , int c){color[u] = c ;for(vector<int>::iter

整数Hash散列总结

方法:    step1  :线性探测  step2 散列   当 h(k)位置已经存储有元素的时候,依次探查(h(k)+i) mod S, i=1,2,3…,直到找到空的存储单元为止。其中,S为 数组长度。 HDU 1496   a*x1^2+b*x2^2+c*x3^2+d*x4^2=0 。 x在 [-100,100] 解的个数  const int MaxN = 3000