This article walks through a PyTorch implementation of the SENet channel attention mechanism. We hope it serves as a useful reference for developers working on this problem — let's go through it together!
SENet channel attention mechanism
The SE (Squeeze-and-Excitation) block first "squeezes" each feature map into a single value with global average pooling, then "excites" it through a small bottleneck MLP ending in a sigmoid, producing one weight per channel that rescales the input feature map. The implementation is as follows:
import torch
import torch.nn as nn


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        # Squeeze: global average pooling collapses each channel to a single value
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck MLP (channel -> channel/reduction -> channel)
        # ending in a sigmoid, giving per-channel weights in (0, 1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        # Rescale the input feature map channel-wise
        return x * y.expand_as(x)


if __name__ == "__main__":
    t = torch.ones((32, 128, 26, 26))
    se = SELayer(channel=128, reduction=16)
    out = se(t)
    print(out.shape)  # torch.Size([32, 128, 26, 26]) -- shape is preserved
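In practice, the SE layer is usually inserted after the convolutions of a block so that the learned channel weights recalibrate that block's output. Below is a minimal sketch of this idea; the `ConvSE` module name and the layer sizes are illustrative choices, not part of the original article:

```python
import torch
import torch.nn as nn


class SELayer(nn.Module):
    """Same SE block as above, repeated so this sketch is self-contained."""
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class ConvSE(nn.Module):
    """Hypothetical conv block: conv -> BN -> ReLU, then SE recalibration."""
    def __init__(self, in_channels, out_channels, reduction=16):
        super(ConvSE, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(out_channels, reduction)

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        # SE reweights the block's output channels before returning
        return self.se(out)


if __name__ == "__main__":
    block = ConvSE(3, 64)
    x = torch.randn(4, 3, 32, 32)
    print(block(x).shape)  # torch.Size([4, 64, 32, 32])
```

Note that the SE layer never changes tensor shape: the 3x3 conv with padding 1 preserves the spatial size, and the SE step only multiplies each channel by a scalar weight.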
That concludes this article on a PyTorch implementation of the SENet channel attention mechanism. We hope it proves helpful to fellow programmers!