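https://blog.csdn.net/ITOMG/article/details/89673593 (read this first to understand SKNet)
https://zhuanlan.zhihu.com/p/59690223 (then read this for the original author's high-level view)
https://github.com/implus/PytorchInsight (PyTorch implementation)
Reading the figure and the code below side by side is basically enough to understand how it is implemented.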
import torch
import torch.nn as nn
import torch.nn.functional as F

# conv1x1 and conv3x3 are helpers from the repository linked above (not shown here).


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        # Branch 1: plain 3x3 convolution
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        # Branch 2: grouped 3x3 convolution
        self.conv2g = conv3x3(planes, planes, stride, groups = 32)
        self.bn2g = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        # Attention path: global pooling followed by two 1x1 convolutions
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_fc1 = nn.Conv2d(planes, planes//16, 1, bias=False)
        self.bn_fc1 = nn.BatchNorm2d(planes//16)
        self.conv_fc2 = nn.Conv2d(planes//16, 2 * planes, 1, bias=False)
        self.D = planes

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Two parallel branches on the same input
        d1 = self.conv2(out)
        d1 = self.bn2(d1)
        d1 = self.relu(d1)

        d2 = self.conv2g(out)
        d2 = self.bn2g(d2)
        d2 = self.relu(d2)

        # Fuse: sum of the globally pooled branch outputs
        d = self.avg_pool(d1) + self.avg_pool(d2)
        d = F.relu(self.bn_fc1(self.conv_fc1(d)))
        d = self.conv_fc2(d)
        # Select: softmax over the two branches, per channel
        d = torch.unsqueeze(d, 1).view(-1, 2, self.D, 1, 1)
        d = F.softmax(d, 1)
        d1 = d1 * d[:, 0, :, :, :].squeeze(1)
        d2 = d2 * d[:, 1, :, :, :].squeeze(1)
        d = d1 + d2

        out = self.conv3(d)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
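The selection step in forward is easier to follow in isolation. The following is a minimal sketch, not the repository's code: random tensors stand in for the two branch outputs and for the output of conv_fc2, and the batch size, channel count, and spatial size are hypothetical values chosen only for illustration.

```python
import torch
import torch.nn.functional as F

N, D = 2, 8                             # hypothetical batch size and channel count
d1 = torch.randn(N, D, 5, 5)            # stand-in for the plain 3x3 branch output
d2 = torch.randn(N, D, 5, 5)            # stand-in for the grouped 3x3 branch output
logits = torch.randn(N, 2 * D, 1, 1)    # stand-in for the output of conv_fc2

# Split the 2*D logits into one D-channel vector per branch.
attn = logits.view(N, 2, D, 1, 1)
# Softmax over the branch axis: for each channel the two weights sum to 1.
attn = F.softmax(attn, dim=1)

# Weighted sum of the two branches, broadcast over the spatial dimensions.
fused = d1 * attn[:, 0] + d2 * attn[:, 1]
```

Each channel therefore receives a convex combination of the two branches, with weights that depend on the input through the pooled features.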
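The PyTorch source seems to differ from the paper in two ways.
First difference:
In the paper, one branch is a 3×3 convolution and the branch for the other scale is a 3×3 convolution with dilation 2 (standing in for a 5×5 kernel).
The author's implementation, however, uses groups instead of dilation for the second branch. The reason is unclear. See conv2g: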
self.conv2 = conv3x3(planes, planes, stride)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2g = conv3x3(planes, planes, stride, groups = 32)
self.bn2g = nn.BatchNorm2d(planes)
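To make the trade-off between the two options concrete, the arithmetic below compares weight counts and spatial extent. It is a rough sketch: conv_weight_count and effective_kernel are hypothetical helper names introduced here, planes = 128 is an assumed channel count, and bias terms are ignored.

```python
def conv_weight_count(in_ch, out_ch, kernel, groups=1):
    """Number of weights in a bias-free 2D convolution."""
    return (in_ch // groups) * out_ch * kernel * kernel


def effective_kernel(kernel, dilation=1):
    """Spatial extent covered by a (possibly dilated) square kernel."""
    return kernel + (kernel - 1) * (dilation - 1)


planes = 128  # hypothetical channel count, divisible by 32

dense_weights = conv_weight_count(planes, planes, 3)
grouped_weights = conv_weight_count(planes, planes, 3, groups=32)
# Dilation spreads the same 3x3 weights over a wider window, so the count is unchanged.
dilated_weights = conv_weight_count(planes, planes, 3)

print(dense_weights, grouped_weights, dilated_weights)        # 147456 4608 147456
print(effective_kernel(3), effective_kernel(3, dilation=2))   # 3 5
```

If the conv3x3 helper applies no dilation of its own, both branches in the code cover a 3×3 window, so they differ in channel connectivity and weight count rather than in receptive field.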
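Second difference:
The paper implements the squeeze-and-excitation-style module with fully connected layers, whereas the code uses two 1×1 convolutions.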
self.conv_fc1 = nn.Conv2d(planes, planes//16, 1, bias=False)
self.bn_fc1 = nn.BatchNorm2d(planes//16)
self.conv_fc2 = nn.Conv2d(planes//16, 2 * planes, 1, bias=False)
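This difference is probably only one of form: on a tensor whose spatial size is 1×1, as produced by AdaptiveAvgPool2d(1), a 1×1 convolution computes the same thing as a fully connected layer with the same weights. The sketch below checks this numerically. The width of 64 channels and the batch size of 3 are assumptions made for illustration, and the batch norm layer is left out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
planes, reduced = 64, 64 // 16    # hypothetical width; reduction of 16 as in the code above

conv = nn.Conv2d(planes, reduced, kernel_size=1, bias=False)
fc = nn.Linear(planes, reduced, bias=False)

# Copy the conv weights (reduced, planes, 1, 1) into the linear layer (reduced, planes).
with torch.no_grad():
    fc.weight.copy_(conv.weight.view(reduced, planes))

pooled = torch.randn(3, planes, 1, 1)    # shape produced by AdaptiveAvgPool2d(1)
out_conv = conv(pooled).flatten(1)       # shape (3, reduced)
out_fc = fc(pooled.flatten(1))           # shape (3, reduced)

same = torch.allclose(out_conv, out_fc, atol=1e-6)
```

Using convolutions keeps the tensor in (N, C, 1, 1) layout, which avoids reshaping before the result is multiplied with the branch feature maps.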