Pytorch之经典神经网络语义分割(3.2) —— ASPP 空洞空间金字塔池化(atrous spatial pyramid pooling )

本文主要是介绍Pytorch之经典神经网络语义分割(3.2) —— ASPP 空洞空间金字塔池化(atrous spatial pyramid pooling ),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

      ASPP是基于空洞卷积(Dilatd/Atrous Convolution)SPP(空间金字塔池化)的。用空洞卷积代替了单纯的adaptivepooling. ASPP对所给定的输入以不同采样率的空洞卷积并行采样,相当于以多个比例捕捉图像的上下文。

     

      ASPP实际上是空间金字塔池的一个版本,其中的概念已经在SPPNet中描述。在ASPP中,在输入特征映射中应用不同速率的并行空洞卷积,并融合在一起。由于同一类的物体在图像中可能有不同的比例,ASPP有助于考虑不同的物体比例,这可以提高准确性。

      DeepLab v2中就有用到ASPP模块

      这里设计了几种不同采样率的空洞卷积来捕捉多尺度信息,但我们要明白采样率(dilation rate)并不是越大越好,因为采样率太大,会导致滤波器有的会跑到padding上,产生无意义的权重,因此要选择合适的采样率。

Pytorch实现

import torch
from torch import nn
import torch.nn.functional as Fclass ASPP(nn.Module):def __init__(self, num_classes):super(ASPP, self).__init__()self.conv_1x1_1 = nn.Conv2d(2048, 256, kernel_size=1)self.bn_conv_1x1_1 = nn.BatchNorm2d(256)self.conv_3x3_6 = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=6, dilation=6)self.bn_conv_3x3_6 = nn.BatchNorm2d(256)self.conv_3x3_12 = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=12, dilation=12)self.bn_conv_3x3_12 = nn.BatchNorm2d(256)self.conv_3x3_18 = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=18, dilation=18)self.bn_conv_3x3_18 = nn.BatchNorm2d(256)self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv_1x1_2 = nn.Conv2d(2048, 256, kernel_size=1)self.bn_conv_1x1_2 = nn.BatchNorm2d(256)self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256)self.bn_conv_1x1_3 = nn.BatchNorm2d(256)self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)def forward(self, feature_map):# (feature_map has shape (batch_size, 2048, h/8, w/8))feature_map_h = feature_map.size()[2] # (h/8)feature_map_w = feature_map.size()[3] # (w/8)out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map))) # (shape: (batch_size, 256, h/8, w/8)) 对应图中 Eout_3x3_1 = F.relu(self.bn_conv_3x3_6(self.conv_3x3_6(feature_map))) # (shape: (batch_size, 256, h/8, w/8)) 对应图中 Dout_3x3_2 = F.relu(self.bn_conv_3x3_12(self.conv_3x3_12(feature_map))) # (shape: (batch_size, 256, h/8, w/8)) 对应图中 Cout_3x3_3 = F.relu(self.bn_conv_3x3_18(self.conv_3x3_18(feature_map))) # (shape: (batch_size, 256, h/8, w/8)) 对应图中 B#out_1x1,out_3x3_1,out_3x3_2,out_3x3_3 的shape都一样out_img = self.avg_pool(feature_map) # (shape: (batch_size, 512, 1, 1))对应图中 ImagePoolingout_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img))) # (shape: (batch_size, 256, 1, 1)) out_img = F.upsample(out_img, size=(feature_map_h, feature_map_w), mode="bilinear") # (shape: (batch_size, 256, h/8, w/8))对应图中 Aout = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], dim=1) # (shape: (batch_size, 1280, h/8, w/8)) cat对应图中 F out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out))) # (shape: (batch_size, 256, h/8, w/8)) bn_conv_1x1_3对应图中 H  out 对应图中Iout = self.conv_1x1_4(out) # (shape: (batch_size, num_classes, h/8, w/8))out 对应图中Upsample by 4return outif __name__ == '__main__':x = torch.rand(4,2048,28,28) #[b,c,h,w]aspp = ASPP(num_classes=10)out = aspp(x) #[b,num_class,h,w]

这篇关于Pytorch之经典神经网络语义分割(3.2) —— ASPP 空洞空间金字塔池化(atrous spatial pyramid pooling )的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/346964

相关文章

Python实现图片分割的多种方法总结

《Python实现图片分割的多种方法总结》图片分割是图像处理中的一个重要任务,它的目标是将图像划分为多个区域或者对象,本文为大家整理了一些常用的分割方法,大家可以根据需求自行选择... 目录1. 基于传统图像处理的分割方法(1) 使用固定阈值分割图片(2) 自适应阈值分割(3) 使用图像边缘检测分割(4)

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

Python如何将大TXT文件分割成4KB小文件

《Python如何将大TXT文件分割成4KB小文件》处理大文本文件是程序员经常遇到的挑战,特别是当我们需要把一个几百MB甚至几个GB的TXT文件分割成小块时,下面我们来聊聊如何用Python自动完成这... 目录为什么需要分割TXT文件基础版:按行分割进阶版:精确控制文件大小完美解决方案:支持UTF-8编码

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

C++字符串提取和分割的多种方法

《C++字符串提取和分割的多种方法》在C++编程中,字符串处理是一个常见的任务,尤其是在需要从字符串中提取特定数据时,本文将详细探讨如何使用C++标准库中的工具来提取和分割字符串,并分析不同方法的适用... 目录1. 字符串提取的基本方法1.1 使用 std::istringstream 和 >> 操作符示

查看Oracle数据库中UNDO表空间的使用情况(最新推荐)

《查看Oracle数据库中UNDO表空间的使用情况(最新推荐)》Oracle数据库中查看UNDO表空间使用情况的4种方法:DBA_TABLESPACES和DBA_DATA_FILES提供基本信息,V$... 目录1. 通过 DBjavascriptA_TABLESPACES 和 DBA_DATA_FILES

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对