This article, the third in the series, covers the implementation of the deep network structure of deeplabv3+; I hope it serves as a useful reference for developers working with this model.
In the first article we mentioned that "the encoder consists of two main parts: the backbone (DCNN) and ASPP". The backbone here is the MobileNetV2 or Xception network, while the ASPP module is the deep network structure discussed in this article. Its structure is shown below:
The principle behind ASPP is quite simple; see the earlier post "1.deeplabv3+网络结构及原理" (CSDN blog) for an introduction. In the structure above, rate denotes the dilation rate of the atrous (dilated) convolution. The module contains five parallel branches: a 1x1 convolution, three 3x3 atrous convolutions with different rates, and a global-pooling branch. Their outputs are stacked with concat and finally fused by a 1x1 convolution into the green layer shown in the figure.
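One detail worth noting before reading the code: for a 3x3 atrous convolution with stride 1, setting padding equal to the dilation rate keeps the output spatial size identical to the input, which is what allows the five branches to be concatenated along the channel dimension later. A minimal check (the tensor shape and channel numbers here are only illustrative assumptions, not values fixed by the article):

import torch
import torch.nn as nn

x = torch.randn(1, 320, 64, 64)   # dummy feature map, shape chosen for illustration
for rate in (6, 12, 18):
    conv = nn.Conv2d(320, 256, kernel_size=3, stride=1, padding=rate, dilation=rate)
    print(rate, conv(x).shape)    # every rate yields torch.Size([1, 256, 64, 64])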
The code for the deep network structure is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F

#-----------------------------------------#
#   ASPP feature extraction module
#   Extracts features with atrous (dilated) convolutions
#   of different dilation rates
#-----------------------------------------#
class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        # Branch 1: 1x1 convolution
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # Branches 2-4: 3x3 atrous convolutions with dilation rates 6, 12, 18
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # Branch 5: 1x1 convolution applied after global average pooling
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)
        # 1x1 convolution that fuses the concatenated branch outputs
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        [b, c, row, col] = x.size()
        #-----------------------------------------#
        #   Five branches in total
        #-----------------------------------------#
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        #-----------------------------------------#
        #   Fifth branch: global average pooling + convolution,
        #   then bilinear upsampling back to the input resolution
        #-----------------------------------------#
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
        #-----------------------------------------#
        #   Stack the outputs of the five branches,
        #   then fuse the features with a 1x1 convolution.
        #-----------------------------------------#
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result
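A quick way to sanity-check the module is to push a dummy feature map through it. The channel sizes below (320 in, 256 out) are only an illustrative assumption in the spirit of a MobileNetV2-style backbone output, not values mandated by the article:

# Minimal usage sketch (channel sizes are illustrative assumptions)
aspp = ASPP(dim_in=320, dim_out=256, rate=1)
features = torch.randn(2, 320, 64, 64)   # e.g. backbone feature map
out = aspp(features)
print(out.shape)                          # torch.Size([2, 256, 64, 64])

The spatial size is preserved and the channel count becomes dim_out, which is exactly what the decoder part of deeplabv3+ expects from the ASPP output.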
That concludes this article on the implementation of the deep network structure of deeplabv3+; I hope it proves helpful.