pytorch构建deeplabv3+

2024-03-27 03:38
文章标签 构建 pytorch deeplabv3

本文主要是介绍pytorch构建deeplabv3+,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DeepLab v3+ 是DeepLab语义分割系列网络的最新作,其前作有 DeepLab v1,v2, v3, 在最新作中,Liang-Chieh Chen等人通过encoder-decoder进行多尺度信息的融合,同时保留了原来的空洞卷积和ASSP层, 其骨干网络使用了Xception模型,提高了语义分割的健壮性和运行速率。其在Pascal VOC上达到了 89.0% 的mIoU,在Cityscape上也取得了 82.1%的好成绩,下图展示了DeepLab v3+的基本结构:

请添加图片描述
其实在DCNN中主要是做一个特征提取,至于采用哪个网络做backbone具体问题具体对待,在这里我才用的是mobilenetv2(只是将deepwise_conv中添加了dilation, 添加空洞卷积是为了增大感受野)

网络结构分为Encode部分和decoder部分
先看encoder部分:
请添加图片描述
接在DCNN后面的实际上就是一个ASPP结构(采用不同的采样率来对特征图做空洞卷积),然后再将对应的结果进行拼接,需要注意的是传入ASPP结构的是DCNN得到的高层特征图image Pooling部分其实会改变特征图的尺寸,所以可以通过使用双线插值(为什么采用双线插值,因为简单)或者其他方式保证经过ASPP结构的各个特征图尺寸相同,最后再进行拼接
请添加图片描述

再看decoder部分请添加图片描述
decoder部分首先会对传入的低层特征图进行通道调整,然后与encoder传入的特征图进行拼接,注意encoder传入的特征图需要经过上采样处理(维持与低层特征图相同的尺寸),最后输出部分只需要将尺寸还原到输入图片的尺寸就行了

import torch
import torch.nn as nn
import torch.functional as Fclass ASPP(nn.Module):def __init__(self, feature, atrous):super(ASPP, self).__init__()self.feature = featureself.Conv1 = _Deepwise_Conv(in_channels=feature.size()[1], out_channels=256, use_bias=False)self.Conv_rate1 = _Deepwise_Conv(in_channels=feature.size()[1], out_channels=256, rate=atrous[0],padding=atrous[0], use_bias=False)self.Conv_rate2 = _Deepwise_Conv(in_channels=feature.size()[1], out_channels=256, rate=atrous[1],padding=atrous[1], use_bias=False)self.Conv_rate3 = _Deepwise_Conv(in_channels=feature.size()[1], out_channels=256, rate=atrous[2],padding=atrous[2], use_bias=False)self.globalAvgPoolAndConv = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),Conv(in_channels=320, out_channels=256, kernel_size=1, stride=1, use_bias=False),)self.Conv4 = Conv(in_channels=256 * 5, out_channels=256, kernel_size=1, stride=1, use_bias=False)self.dropout = nn.Dropout(p=0.1)def forward(self):f1 = self.Conv1(self.feature.clone())f2 = self.Conv_rate1(self.feature.clone())f3 = self.Conv_rate2(self.feature.clone())f4 = self.Conv_rate3(self.feature.clone())f5 = self.globalAvgPoolAndConv(self.feature.clone())f5 = F.interpolate(f5, size=(self.feature.size(2), self.feature.size(3)), mode='bilinear')x = torch.cat([f1, f2, f3, f4, f5], dim=1)x = self.Conv4(x)x = self.dropout(x)class Deeplabv3(nn.Module):def __init__(self, feature, atrous, skip1, num_class):super(Deeplabv3, self).__init__()self.num_class = num_classself.feature = ASPP(atrous=atrous, feature=feature).forward()self.skip1 = skip1self.encoder = ASPP(atrous=atrous, feature=feature)self.Conv1 = Conv(in_channels=skip1.size()[1], out_channels=48, kernel_size=1,strip=1, use_bias=False)self.Conv2 = _Deepwise_Conv(in_channels=48 + 256, out_channels=256, use_bias=False)self.ConvNUM = Conv(in_channels=256, out_channels=num_class, kernel_size=1, use_bias=False)def forward(self, input_img):skip1 = self.Conv1(self.skip1)feature = F.interpolate(self.feature, size=(skip1.size()[2], skip1.size()[3]), mode='bilinear')skip1 = torch.cat([skip1, feature], dim=1)skip1 = self.Conv2(skip1)skip1 = self.ConvNUM(skip1)skip1 = F.interpolate(skip1, size=(input_img.size()[2], input_img.size()[3]))return F.softmax(skip1,dim=1)class _bottlenet(nn.Module):def __init__(self, in_channels, out_channels, rate=1, expand_ratio=1, stride=1):super(_bottlenet, self).__init__()# 步长为2以及前后通道数不同就不进行残差堆叠self.use_res_connect = (stride == 1) and (in_channels == out_channels)self.features = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=in_channels * expand_ratio, kernel_size=1),nn.BatchNorm2d(num_features=in_channels * expand_ratio),nn.ReLU6(inplace=True),nn.Conv2d(in_channels=in_channels * expand_ratio, out_channels=in_channels * expand_ratio, kernel_size=3, stride=stride,padding=rate, dilation=(rate, rate)),nn.BatchNorm2d(num_features=in_channels * expand_ratio),nn.ReLU6(inplace=True),nn.Conv2d(in_channels=in_channels * expand_ratio, out_channels=out_channels, stride=1, kernel_size=1,padding=0),nn.BatchNorm2d(num_features=out_channels),nn.ReLU6(inplace=True),)# self.change = nn.Conv2d()def forward(self, x):x_clone = x.clone()x = self.features(x)#         print(x.size())if self.use_res_connect:#             print("="*10)#             print(x.size())#             print(x_clone.size())x.add_(x_clone)return xclass get_mobilenetv2_encoder(nn.Module):def __init__(self, downsamp_factor=8, num_classes=3):super(get_mobilenetv2_encoder, self).__init__()if downsamp_factor == 8:self.atrous_rates = (12, 24, 36)block4_dilation = 2block5_dilation = 4block4_stride = 1else:self.atrous_rates = (6, 12, 18)block4_dilation = 1block5_dilation = 2block4_stride = 2self.features = []self.features.append(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=1, stride=2))self.features.append(nn.BatchNorm2d(num_features=32))self.features.append(nn.ReLU6(inplace=True))# ------  3 ------# block1self.features.append(_bottlenet(in_channels=32, out_channels=16, expand_ratio=1, stride=1))# block2# [t, c, n, s] = [6, 24, 2, 2]self.features.append(_bottlenet(in_channels=16, out_channels=24, expand_ratio=6, stride=2))self.features.append(_bottlenet(in_channels=24, out_channels=24, expand_ratio=6, stride=1))# ------  6  -----# block3# [t, c, n, s] = [6, 32, 3, 2]self.features.append(_bottlenet(in_channels=24, out_channels=32, expand_ratio=6, stride=2))for i in range(2):self.features.append(_bottlenet(in_channels=32, out_channels=32, expand_ratio=6))# ------  9  ------# block4# [t, c, n, s] = [6, 64, 4, 2]self.features.append(_bottlenet(in_channels=32, out_channels=64, expand_ratio=6, stride=block4_stride))for i in range(3):self.features.append(_bottlenet(in_channels=64, out_channels=64, expand_ratio=6, rate=block4_dilation))# ------  13  ------# block5# [t, c, n, s] = [6, 96, 3, 1]self.features.append(_bottlenet(in_channels=64, out_channels=96, expand_ratio=6, rate=block4_dilation))for i in range(2):self.features.append(_bottlenet(in_channels=96, out_channels=96, expand_ratio=6, rate=block4_dilation))# [t, c, n, s] = [6, 160, 3, 2]# block6self.features.append(_bottlenet(in_channels=96, out_channels=160, expand_ratio=6, stride=1))for i in range(2):self.features.append(_bottlenet(in_channels=160, out_channels=160, expand_ratio=6))# [t, c, n, s] = [6, 160, 3, 2]self.features.append(_bottlenet(in_channels=160, out_channels=320, expand_ratio=6))self.features = nn.Sequential(*self.features)def forward(self, x):skip1 = Nonefor i, op in enumerate(self.features, 0):x = op(x)if i == 5:skip1 = x.clone()return x, self.atrous_rates, skip1class pool_block(nn.Module):def __init__(self, f, stride):super(pool_block, self).__init__()in_channels = f.size()[1]kernel_size = strideself.features = nn.Sequential(nn.AvgPool2d(kernel_size=kernel_size, stride=kernel_size, padding=kernel_size // 2),nn.Conv2d(in_channels=in_channels, out_channels=512, kernel_size=1, stride=1, bias=False),nn.BatchNorm2d(num_features=512),nn.ReLU6(inplace=True),nn.Upsample(size=(INPUT_SIZE, INPUT_SIZE), mode="bilinear"))def forward(self, x):x = self.features(x)return xclass _Deepwise_Conv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, rate=1, use_bias=False):super(_Deepwise_Conv, self).__init__()self.conv1 = Conv(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,stride=stride, padding=padding, dilation=rate, use_bias=use_bias)self.conv2 = Conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1,stride=1, padding=0, use_bias=use_bias)def forward(self, x):return self.conv2(self.conv1(x))class Conv(nn.Module):'''nn.Conv2d + Batchnormlizetion + ReLU6'''def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, use_bias=False):super(Conv, self).__init__()self.features = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,stride=stride, padding=padding, dilation=dilation, bias=use_bias),nn.BatchNorm2d(num_features=out_channels),nn.ReLU6(),)def forward(self, x):return self.features(x)

参考链接如下:
https://blog.csdn.net/weixin_44791964/article/details/103017389
https://zhuanlan.zhihu.com/p/68531147

这篇关于pytorch构建deeplabv3+的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/850803

相关文章

嵌入式QT开发:构建高效智能的嵌入式系统

摘要: 本文深入探讨了嵌入式 QT 相关的各个方面。从 QT 框架的基础架构和核心概念出发,详细阐述了其在嵌入式环境中的优势与特点。文中分析了嵌入式 QT 的开发环境搭建过程,包括交叉编译工具链的配置等关键步骤。进一步探讨了嵌入式 QT 的界面设计与开发,涵盖了从基本控件的使用到复杂界面布局的构建。同时也深入研究了信号与槽机制在嵌入式系统中的应用,以及嵌入式 QT 与硬件设备的交互,包括输入输出设

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

maven 编译构建可以执行的jar包

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」👈,「stormsha的知识库」👈持续学习,不断总结,共同进步,为了踏实,做好当下事儿~ 专栏导航 Python系列: Python面试题合集,剑指大厂Git系列: Git操作技巧GO

嵌入式Openharmony系统构建与启动详解

大家好,今天主要给大家分享一下,如何构建Openharmony子系统以及系统的启动过程分解。 第一:OpenHarmony系统构建      首先熟悉一下,构建系统是一种自动化处理工具的集合,通过将源代码文件进行一系列处理,最终生成和用户可以使用的目标文件。这里的目标文件包括静态链接库文件、动态链接库文件、可执行文件、脚本文件、配置文件等。      我们在编写hellowor

利用命令模式构建高效的手游后端架构

在现代手游开发中,后端架构的设计对于支持高并发、快速迭代和复杂游戏逻辑至关重要。命令模式作为一种行为设计模式,可以有效地解耦请求的发起者与接收者,提升系统的可维护性和扩展性。本文将深入探讨如何利用命令模式构建一个强大且灵活的手游后端架构。 1. 命令模式的概念与优势 命令模式通过将请求封装为对象,使得请求的发起者和接收者之间的耦合度降低。这种模式的主要优势包括: 解耦请求发起者与处理者

Jenkins构建Maven聚合工程,指定构建子模块

一、设置单独编译构建子模块 配置: 1、Root POM指向父pom.xml 2、Goals and options指定构建模块的参数: mvn -pl project1/project1-son -am clean package 单独构建project1-son项目以及它所依赖的其它项目。 说明: mvn clean package -pl 父级模块名/子模块名 -am参数

JAVA用最简单的方法来构建一个高可用的服务端,提升系统可用性

一、什么是提升系统的高可用性 JAVA服务端,顾名思义就是23体验网为用户提供服务的。停工时间,就是不能向用户提供服务的时间。高可用,就是系统具有高度可用性,尽量减少停工时间。如何用最简单的方法来搭建一个高效率可用的服务端JAVA呢? 停工的原因一般有: 服务器故障。例如服务器宕机,服务器网络出现问题,机房或者机架出现问题等;访问量急剧上升,导致服务器压力过大导致访问量急剧上升的原因;时间和

利用Django框架快速构建Web应用:从零到上线

随着互联网的发展,Web应用的需求日益增长,而Django作为一个高级的Python Web框架,以其强大的功能和灵活的架构,成为了众多开发者的选择。本文将指导你如何从零开始使用Django框架构建一个简单的Web应用,并将其部署到线上,让世界看到你的作品。 Django简介 Django是由Adrian Holovaty和Simon Willison于2005年开发的一个开源框架,旨在简

828华为云征文|华为云Flexus X实例docker部署rancher并构建k8s集群

828华为云征文|华为云Flexus X实例docker部署rancher并构建k8s集群 华为云最近正在举办828 B2B企业节,Flexus X实例的促销力度非常大,特别适合那些对算力性能有高要求的小伙伴。如果你有自建MySQL、Redis、Nginx等服务的需求,一定不要错过这个机会。赶紧去看看吧! 什么是华为云Flexus X实例 华为云Flexus X实例云服务是新一代开箱即用、体

构建高性能WEB之HTTP首部优化

0x00 前言 在讨论浏览器优化之前,首先我们先分析下从客户端发起一个HTTP请求到用户接收到响应之间,都发生了什么?知己知彼,才能百战不殆。这也是作为一个WEB开发者,为什么一定要深入学习TCP/IP等网络知识。 0x01 到底发生什么了? 当用户发起一个HTTP请求时,首先客户端将与服务端之间建立TCP连接,成功建立连接后,服务端将对请求进行处理,并对客户端做出响应,响应内容一般包括响应