pytorch代码实现注意力机制之SGE

2023-10-16 03:50

本文主要是介绍pytorch代码实现注意力机制之SGE,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

SGE注意力机制

SGE注意力机制是一种轻量attention模块,其亮点就是同时几乎不增加参数量和计算量的情况下也能让分类与检测性能得到极强的增益。同时,与其他attention模块相比,它是首个利用local与global的相似性作为attention mask的generation source,同时具有非常强的语义表示增强的可解释性。
SGE注意力模块通过在在每个group里生成attention factor,这样就能得到每个sub feature的重要性,每个group也可以有针对性的学习和抑制噪声。这个attention factor仅由各个group内全局和局部特征之间的相似性来决定,所以SGE非常轻量级。经由训练之后发现,SGE对于一些高阶语意非常有效。

论文地址:https://arxiv.org/pdf/1905.09646.pdf
![结构原理图](https://img-blog.csdnimg.cn/cb33e483e7134516a417195a51eaf6f5.png)

代码如下:

import numpy as np
import torch
from torch import nn
from torch.nn import initclass SpatialGroupEnhance(nn.Module):def __init__(self, groups=8):super().__init__()self.groups=groupsself.avg_pool = nn.AdaptiveAvgPool2d(1)self.weight=nn.Parameter(torch.zeros(1,groups,1,1))self.bias=nn.Parameter(torch.zeros(1,groups,1,1))self.sig=nn.Sigmoid()self.init_weights()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, h,w=x.shapex=x.view(b*self.groups,-1,h,w) #bs*g,dim//g,h,wxn=x*self.avg_pool(x) #bs*g,dim//g,h,wxn=xn.sum(dim=1,keepdim=True) #bs*g,1,h,wt=xn.view(b*self.groups,-1) #bs*g,h*wt=t-t.mean(dim=1,keepdim=True) #bs*g,h*wstd=t.std(dim=1,keepdim=True)+1e-5t=t/std #bs*g,h*wt=t.view(b,self.groups,h,w) #bs,g,h*wt=t*self.weight+self.bias #bs,g,h*wt=t.view(b*self.groups,1,h,w) #bs*g,1,h*wx=x*self.sig(t)x=x.view(b,c,h,w)return x if __name__ == '__main__':input=torch.randn(50,512,7,7)sge = SpatialGroupEnhance(groups=8)output=sge(input)print(output.shape)

这篇关于pytorch代码实现注意力机制之SGE的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/218833

相关文章

利用c++判断水仙花数并输出示例代码

《利用c++判断水仙花数并输出示例代码》水仙花数是指一个三位数,其各位数字的立方和恰好等于该数本身,:本文主要介绍利用c++判断水仙花数并输出的相关资料,文中通过代码介绍的非常详细,需要的朋友可以... 以下是使用C++实现的相同逻辑代码:#include <IOStream>#include <vec

基于C++的UDP网络通信系统设计与实现详解

《基于C++的UDP网络通信系统设计与实现详解》在网络编程领域,UDP作为一种无连接的传输层协议,以其高效、低延迟的特性在实时性要求高的应用场景中占据重要地位,下面我们就来看看如何从零开始构建一个完整... 目录前言一、UDP服务器UdpServer.hpp1.1 基本框架设计1.2 初始化函数Init详解

Java中Map的五种遍历方式实现与对比

《Java中Map的五种遍历方式实现与对比》其实Map遍历藏着多种玩法,有的优雅简洁,有的性能拉满,今天咱们盘一盘这些进阶偏基础的遍历方式,告别重复又臃肿的代码,感兴趣的小伙伴可以了解下... 目录一、先搞懂:Map遍历的核心目标二、几种遍历方式的对比1. 传统EntrySet遍历(最通用)2. Lambd

springboot+redis实现订单过期(超时取消)功能的方法详解

《springboot+redis实现订单过期(超时取消)功能的方法详解》在SpringBoot中使用Redis实现订单过期(超时取消)功能,有多种成熟方案,本文为大家整理了几个详细方法,文中的示例代... 目录一、Redis键过期回调方案(推荐)1. 配置Redis监听器2. 监听键过期事件3. Redi

SpringBoot全局异常拦截与自定义错误页面实现过程解读

《SpringBoot全局异常拦截与自定义错误页面实现过程解读》本文介绍了SpringBoot中全局异常拦截与自定义错误页面的实现方法,包括异常的分类、SpringBoot默认异常处理机制、全局异常拦... 目录一、引言二、Spring Boot异常处理基础2.1 异常的分类2.2 Spring Boot默

基于SpringBoot实现分布式锁的三种方法

《基于SpringBoot实现分布式锁的三种方法》这篇文章主要为大家详细介绍了基于SpringBoot实现分布式锁的三种方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、基于Redis原生命令实现分布式锁1. 基础版Redis分布式锁2. 可重入锁实现二、使用Redisso

SpringBoo WebFlux+MongoDB实现非阻塞API过程

《SpringBooWebFlux+MongoDB实现非阻塞API过程》本文介绍了如何使用SpringBootWebFlux和MongoDB实现非阻塞API,通过响应式编程提高系统的吞吐量和响应性能... 目录一、引言二、响应式编程基础2.1 响应式编程概念2.2 响应式编程的优势2.3 响应式编程相关技术

Java 接口定义变量的示例代码

《Java接口定义变量的示例代码》文章介绍了Java接口中的变量和方法,接口中的变量必须是publicstaticfinal的,用于定义常量,而方法默认是publicabstract的,必须由实现类... 在 Java 中,接口是一种抽象类型,用于定义类必须实现的方法。接口可以包含常量和方法,但不能包含实例

C#实现将XML数据自动化地写入Excel文件

《C#实现将XML数据自动化地写入Excel文件》在现代企业级应用中,数据处理与报表生成是核心环节,本文将深入探讨如何利用C#和一款优秀的库,将XML数据自动化地写入Excel文件,有需要的小伙伴可以... 目录理解XML数据结构与Excel的对应关系引入高效工具:使用Spire.XLS for .NETC

Nginx更新SSL证书的实现步骤

《Nginx更新SSL证书的实现步骤》本文主要介绍了Nginx更新SSL证书的实现步骤,包括下载新证书、备份旧证书、配置新证书、验证配置及遇到问题时的解决方法,感兴趣的了解一下... 目录1 下载最新的SSL证书文件2 备份旧的SSL证书文件3 配置新证书4 验证配置5 遇到的http://www.cppc