YOLOv7添加注意力机制和各种改进模块

2024-05-29 20:52

本文主要是介绍YOLOv7添加注意力机制和各种改进模块,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

YOLOv7添加注意力机制和各种改进模块代码免费下载:完整代码

添加的部分模块代码:

########CBAM
class ChannelAttentionModule(nn.Module):def __init__(self, c1, reduction=16):super(ChannelAttentionModule, self).__init__()mid_channel = c1 // reductionself.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.shared_MLP = nn.Sequential(nn.Linear(in_features=c1, out_features=mid_channel),nn.LeakyReLU(0.1, inplace=True),nn.Linear(in_features=mid_channel, out_features=c1))self.act = nn.Sigmoid()# self.act=nn.SiLU()def forward(self, x):avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)maxout = self.shared_MLP(self.max_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)return self.act(avgout + maxout)class SpatialAttentionModule(nn.Module):def __init__(self):super(SpatialAttentionModule, self).__init__()self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)self.act = nn.Sigmoid()def forward(self, x):avgout = torch.mean(x, dim=1, keepdim=True)maxout, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avgout, maxout], dim=1)out = self.act(self.conv2d(out))return outclass CBAM(nn.Module):def __init__(self, c1, c2):super(CBAM, self).__init__()self.channel_attention = ChannelAttentionModule(c1)self.spatial_attention = SpatialAttentionModule()def forward(self, x):out = self.channel_attention(x) * xout = self.spatial_attention(out) * outreturn out
##############CBAM
########SE
class SEAttention(nn.Module):def __init__(self, channel=512,reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)
########SE
#######GAM
class GAMAttention(nn.Module):# https://paperswithcode.com/paper/global-attention-mechanism-retain-informationdef __init__(self, c1, c2, group=True, rate=4):super(GAMAttention, self).__init__()self.channel_attention = nn.Sequential(nn.Linear(c1, int(c1 / rate)),nn.ReLU(inplace=True),nn.Linear(int(c1 / rate), c1))self.spatial_attention = nn.Sequential(nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate),kernel_size=7,padding=3),nn.BatchNorm2d(int(c1 / rate)),nn.ReLU(inplace=True),nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2,kernel_size=7,padding=3),nn.BatchNorm2d(c2))def forward(self, x):b, c, h, w = x.shapex_permute = x.permute(0, 2, 3, 1).view(b, -1, c)x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)x_channel_att = x_att_permute.permute(0, 3, 1, 2)x = x * x_channel_attx_spatial_att = self.spatial_attention(x).sigmoid()x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffleout = x * x_spatial_attreturn outdef channel_shuffle(x, groups=2):  ##shuffle channel# RESHAPE----->transpose------->FlattenB, C, H, W = x.size()out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()out = out.view(B, C, H, W)return out
#######GAM
#####NAMAttention  该注意力机制只有通道注意力机制的代码,空间的没有
import torch.nn as nn
import torch
from torch.nn import functional as Fclass Channel_Att(nn.Module):def __init__(self, channels, t=16):super(Channel_Att, self).__init__()self.channels = channelsself.bn2 = nn.BatchNorm2d(self.channels, affine=True)def forward(self, x):residual = xx = self.bn2(x)weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())x = x.permute(0, 2, 3, 1).contiguous()x = torch.mul(weight_bn, x)x = x.permute(0, 3, 1, 2).contiguous()x = torch.sigmoid(x) * residual  #return xclass NAMAttention(nn.Module):def __init__(self, channels, out_channels=None, no_spatial=True):super(NAMAttention, self).__init__()self.Channel_Att = nn.Sequential(*(Channel_Att(channels)for _ in range(1)))def forward(self, x):# print(x.device)## device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')x_out1 = self.Channel_Att(x)return x_out1
#####NAMAttentionclass RepGhostBottleneck1(nn.Module):# RepGhostNeXt Bottleneckdef __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_outsuper().__init__()self.c_ = int(c2 * e)  # hidden channels# attention mechanism can be usedself.m = nn.Sequential(*(RepGhostBottleneck(c1, c2, 2*self.c_) for _ in range(n)))def forward(self, x):return self.m(x)

这篇关于YOLOv7添加注意力机制和各种改进模块的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1014615

相关文章

java中反射(Reflection)机制举例详解

《java中反射(Reflection)机制举例详解》Java中的反射机制是指Java程序在运行期间可以获取到一个对象的全部信息,:本文主要介绍java中反射(Reflection)机制的相关资料... 目录一、什么是反射?二、反射的用途三、获取Class对象四、Class类型的对象使用场景1五、Class

Python使用date模块进行日期处理的终极指南

《Python使用date模块进行日期处理的终极指南》在处理与时间相关的数据时,Python的date模块是开发者最趁手的工具之一,本文将用通俗的语言,结合真实案例,带您掌握date模块的六大核心功能... 目录引言一、date模块的核心功能1.1 日期表示1.2 日期计算1.3 日期比较二、六大常用方法详

python中time模块的常用方法及应用详解

《python中time模块的常用方法及应用详解》在Python开发中,时间处理是绕不开的刚需场景,从性能计时到定时任务,从日志记录到数据同步,时间模块始终是开发者最得力的工具之一,本文将通过真实案例... 目录一、时间基石:time.time()典型场景:程序性能分析进阶技巧:结合上下文管理器实现自动计时

Nginx之upstream被动式重试机制的实现

《Nginx之upstream被动式重试机制的实现》本文主要介绍了Nginx之upstream被动式重试机制的实现,可以通过proxy_next_upstream来自定义配置,具有一定的参考价值,感兴... 目录默认错误选择定义错误指令配置proxy_next_upstreamproxy_next_upst

Node.js net模块的使用示例

《Node.jsnet模块的使用示例》本文主要介绍了Node.jsnet模块的使用示例,net模块支持TCP通信,处理TCP连接和数据传输,具有一定的参考价值,感兴趣的可以了解一下... 目录简介引入 net 模块核心概念TCP (传输控制协议)Socket服务器TCP 服务器创建基本服务器服务器配置选项服

Spring排序机制之接口与注解的使用方法

《Spring排序机制之接口与注解的使用方法》本文介绍了Spring中多种排序机制,包括Ordered接口、PriorityOrdered接口、@Order注解和@Priority注解,提供了详细示例... 目录一、Spring 排序的需求场景二、Spring 中的排序机制1、Ordered 接口2、Pri

Python利用自带模块实现屏幕像素高效操作

《Python利用自带模块实现屏幕像素高效操作》这篇文章主要为大家详细介绍了Python如何利用自带模块实现屏幕像素高效操作,文中的示例代码讲解详,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1、获取屏幕放缩比例2、获取屏幕指定坐标处像素颜色3、一个简单的使用案例4、总结1、获取屏幕放缩比例from

MySQL 缓存机制与架构解析(最新推荐)

《MySQL缓存机制与架构解析(最新推荐)》本文详细介绍了MySQL的缓存机制和整体架构,包括一级缓存(InnoDBBufferPool)和二级缓存(QueryCache),文章还探讨了SQL... 目录一、mysql缓存机制概述二、MySQL整体架构三、SQL查询执行全流程四、MySQL 8.0为何移除查

nginx-rtmp-module模块实现视频点播的示例代码

《nginx-rtmp-module模块实现视频点播的示例代码》本文主要介绍了nginx-rtmp-module模块实现视频点播,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习... 目录预置条件Nginx点播基本配置点播远程文件指定多个播放位置参考预置条件配置点播服务器 192.

一文详解Java Condition的await和signal等待通知机制

《一文详解JavaCondition的await和signal等待通知机制》这篇文章主要为大家详细介绍了JavaCondition的await和signal等待通知机制的相关知识,文中的示例代码讲... 目录1. Condition的核心方法2. 使用场景与优势3. 使用流程与规范基本模板生产者-消费者示例