[Semantic Segmentation Series] Annotated PointRend Source Code

2023-10-28 15:50

This post walks through an annotated PointRend implementation for semantic segmentation; hopefully it is a useful reference for anyone reading the code.

I'm a beginner here, so corrections from more experienced readers are welcome. The walkthrough below annotates PointRend on top of a DeepLabV3 segmentation framework; the prose describes an xception65 backbone for DeepLabV3, while the runnable demo code at the end swaps in ResNet backbones.

Architecture diagram: skim the walkthrough below first, then come back to this figure; it should be much clearer. One caveat: the diagram does not quite match the code. In the code the fine-grained feature map is 1/4 the size of the input image, whereas the diagram shows it at the full input size. Everything downstream of that is the same.

      

[Figure 1: PointRend architecture diagram]

1. Why PointRend was proposed:

    A conventional semantic segmentation network produces, after a series of convolutions and poolings, a feature map at some reduced resolution, typically 1/8, 1/16, or 1/32 of the input. Every point on that map already carries a class prediction, so we know which class each coarse pixel belongs to. The map is then upsampled back to the input size to obtain the final segmentation, and as you can imagine, object boundaries come out inaccurate after upsampling. PointRend exists to fix those boundaries: it ranks the points on the feature map by an instability (uncertainty) measure, picks out the N least stable points (whose class membership is ambiguous, i.e. messy boundary regions), and re-predicts just those points. In other words, the method is a refinement stage that runs on top of some existing segmentation result.
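To make "instability" concrete: in this code the uncertainty of a point is the negative gap between its two highest class scores, so a gap near zero means the model cannot decide between two classes. A minimal sketch of that ranking (random logits standing in for a real coarse prediction):

import torch

logits = torch.randn(1, 19, 64, 128)        # stand-in coarse prediction [B, C, H, W]

top2, _ = logits.topk(2, dim=1)             # top-2 class scores per pixel
uncertainty = -(top2[:, 0] - top2[:, 1])    # gap near 0  =>  high uncertainty

N = 48                                      # keep the N most uncertain pixels
_, idx = uncertainty.view(1, -1).topk(N, dim=1)
print(idx.shape)                            # torch.Size([1, 48])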

2. PointRend training flow:

a. Rank the points on the feature map by instability and pick N of them. Note that during training N is x.shape[-1] // 16 (32 for the 512-wide demo input below); the N = 8096 that appears in the code is the inference-time setting.

The relevant line: points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)
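The training-time selection is not a plain top-N. sampling_points first draws k * N random candidate points, keeps the beta * N most uncertain of them, and fills the rest with uniformly random points for coverage. With the defaults used in this code the arithmetic works out as:

N = 512 // 16                 # 32 points per image for a 512-wide input
k, beta = 3, 0.75             # defaults used in PointHead below
candidates = k * N            # 96 uniformly sampled candidate points
importance = int(beta * N)    # 24 kept by uncertainty ranking
coverage = N - importance     # 8 random points kept for coverage
print(candidates, importance, coverage)   # 96 24 8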

b. Extract the features at those N point locations from the backbone's feature maps.

Taking xception65 as the example backbone: it outputs c1 through c4, where c1 is a higher-resolution feature map (1/4 of the input) and c4 is the final feature map (1/16). The features at the N point locations are sampled from both maps. (In the runnable ResNet code below, the corresponding tensors are named res2 and out.)

The relevant lines: coarse = point_sample(out, points, align_corners=False)
                    fine = point_sample(res2, points, align_corners=False)
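point_sample (full source in sampling_points.py below) is a thin wrapper around F.grid_sample that takes coordinates in [0, 1] rather than grid_sample's native [-1, 1]. A tiny sketch of just that coordinate mapping on a 2x2 feature map:

import torch
import torch.nn.functional as F

feat = torch.tensor([[[[0., 1.], [2., 3.]]]])       # [N, C, H, W] = [1, 1, 2, 2]
pts = torch.tensor([[[0.25, 0.25], [0.75, 0.75]]])  # [N, P, 2], (x, y) in [0, 1]

# point_sample maps [0, 1] coords onto grid_sample's [-1, 1] convention
out = F.grid_sample(feat, (2.0 * pts - 1.0).unsqueeze(2), align_corners=False)
print(out.squeeze(3))   # tensor([[[0., 3.]]]) : one feature value per point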

c. Concatenate the per-point features from the two sources with torch.cat. For example, if the coarse features are [1, 19, 8096] and the fine features from c1 are [1, 256, 8096], the result is [1, 275, 8096].

The relevant line: feature_representation = torch.cat([coarse, fine], dim=1)
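A quick shape check of the concatenation, using the sizes from the example above:

import torch

coarse = torch.randn(1, 19, 8096)    # class logits per point
fine = torch.randn(1, 256, 8096)     # c1/res2 features per point
print(torch.cat([coarse, fine], dim=1).shape)   # torch.Size([1, 275, 8096])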

d. Classify each selected point with an MLP.

The relevant line: rend = self.mlp(feature_representation)
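The "MLP" here is an nn.Conv1d with kernel size 1, which applies one shared linear map to every point independently, i.e. a per-point classifier:

import torch
import torch.nn as nn

mlp = nn.Conv1d(275, 19, 1)              # kernel size 1: shared linear layer per point
points_feat = torch.randn(1, 275, 8096)
print(mlp(points_feat).shape)            # torch.Size([1, 19, 8096])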

3. PointRend inference flow:

It differs from the training-path code; the differences are explained in the comments on the key code below.

4. Annotated PointRend key code:

class PointHead(nn.Module):
    def __init__(self, in_c=275, num_classes=19, k=3, beta=0.75):
        super().__init__()
        self.mlp = nn.Conv1d(in_c, num_classes, 1)
        self.k = k
        self.beta = beta

    def forward(self, x, res2, out):
        """
        1. Fine-grained features are interpolated from res2 for DeeplabV3
        2. During training we sample as many points as there are on a stride 16 feature map of the input
        3. To measure prediction uncertainty
           we use the same strategy during training and inference: the difference between the most
           confident and second most confident class probabilities.
        """
        if not self.training:
            return self.inference(x, res2, out)

        points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)  # pick the point coordinates

        coarse = point_sample(out, points, align_corners=False)  # sample high-level (deep) features at the points, from out/c4
        fine = point_sample(res2, points, align_corners=False)   # sample low-level (shallow) features at the points, from res2/c1

        feature_representation = torch.cat([coarse, fine], dim=1)  # concatenate the two per-point feature sets

        rend = self.mlp(feature_representation)  # per-point MLP: each sampled point is re-classified

        return {"rend": rend, "points": points}

    @torch.no_grad()
    def inference(self, x, res2, out):
        """
        During inference, subdivision uses N=8096
        (i.e., the number of points in the stride 16 map of a 1024×2048 image)
        """
        num_points = 8096  # note: a stride-16 map of a 1024x2048 image has 64*128 = 8192 points; 8096 appears to be a typo carried over from the reference code, though it works either way
        # Here `out` is the coarse classification result: high-level features passed through the final
        # per-class convolution, i.e. a rough segmentation of shape [B, num_classes, h, w], where h, w
        # are the downsampled size. The loop below repeatedly upsamples `out`, picks the most unstable
        # points on it, re-predicts them with the MLP, and writes the new predictions over the old
        # values, until `out` reaches the resolution of the input image.
        while out.shape[-1] != x.shape[-1]:  # loop ends once the small map has been interpolated up to the input size
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)  # first upsample the coarse prediction by 2x

            points_idx, points = sampling_points(out, num_points, training=self.training)  # pick the most unstable points on the upsampled map

            coarse = point_sample(out, points, align_corners=False)  # as in training: sample the unstable points on the high-level map
            fine = point_sample(res2, points, align_corners=False)   # as in training: sample them on the low-level map

            feature_representation = torch.cat([coarse, fine], dim=1)  # concatenate the features

            rend = self.mlp(feature_representation)  # as in training; rend is [1, num_classes, 8096]: one logit vector per sampled point

            B, C, H, W = out.shape
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)  # repeat the point indices across every channel

            # scatter_ writes the refined per-point logits back into `out`: for each batch b and
            # channel c, out[b, c, points_idx[b, c, p]] = rend[b, c, p], so the new class scores of
            # the unstable points replace their old values (see the small scatter_ demo after this
            # code block)
            out = (out.reshape(B, C, -1)
                      .scatter_(2, points_idx, rend)
                      .view(B, C, H, W))

            
        return {"fine": out}
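scatter_ is the only non-obvious tensor operation above, so here is a minimal standalone demo of exactly that usage (toy sizes, my own example): two "uncertain" pixels out of six get their logits overwritten, channel by channel.

import torch

out = torch.zeros(1, 2, 6)                    # flattened map: 2 classes, 6 pixels
idx = torch.tensor([[1, 4]])                  # indices of the 2 uncertain pixels
idx = idx.unsqueeze(1).expand(-1, 2, -1)      # [1, 2, 2]: same index for every channel
rend = torch.tensor([[[9., 8.], [7., 6.]]])   # refined logits for those 2 points
out.scatter_(2, idx, rend)                    # out[b, c, idx[b, c, p]] = rend[b, c, p]
print(out)
# tensor([[[0., 9., 0., 0., 8., 0.],
#          [0., 7., 0., 0., 6., 0.]]])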

5. Complete runnable PointRend source:

The code consists of three files.

Run it with:

python pointrend.py

1. Put this code in deeplab.py.

from collections import OrderedDict
from torchvision.models._utils import IntermediateLayerGetter
# from torchvision.models.utils import load_state_dict_from_url  # unused; this import path no longer exists in recent torchvision
from torchvision.models.segmentation._utils import _SimpleSegmentationModel
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation.fcn import FCNHead
# from .resnet import resnet103, resnet53
from torchvision.models import resnet50, resnet101
from torchvision.models.resnet import ResNet, Bottleneck
import torch.nn as nn


class ResNetXX3(ResNet):
    """Standard ResNet with the stem conv replaced by a 3x3 stride-1 conv (keeps more resolution)."""
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__(block, layers, num_classes, zero_init_residual,
                         groups, width_per_group, replace_stride_with_dilation,
                         norm_layer)
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')


def resnet53(pretrained=False, progress=True, **kwargs):
    """ResNet-50 layout with the modified stem above (pretrained/progress are accepted but ignored)."""
    return ResNetXX3(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet103(pretrained=False, progress=True, **kwargs):
    """ResNet-101 layout with the modified stem above (pretrained/progress are accepted but ignored)."""
    return ResNetXX3(Bottleneck, [3, 4, 23, 3], **kwargs)


class SmallDeepLab(_SimpleSegmentationModel):
    def forward(self, input_):
        result = self.backbone(input_)
        result["coarse"] = self.classifier(result["out"])
        return result


def deeplabv3(pretrained=False, resnet="res103", head_in_ch=2048, num_classes=21):
    resnet = {
        "res53":  resnet53,
        "res103": resnet103,
        "res50":  resnet50,
        "res101": resnet101,
    }[resnet]
    net = SmallDeepLab(
        # IntermediateLayerGetter returns the ResNet's layer2 and layer4 outputs,
        # renamed to 'res2' and 'out'
        backbone=IntermediateLayerGetter(
            resnet(pretrained=False, replace_stride_with_dilation=[False, True, True]),
            return_layers={'layer2': 'res2', 'layer4': 'out'}
        ),
        classifier=DeepLabHead(head_in_ch, num_classes)
    )
    return net


if __name__ == "__main__":
    import torch
    x = torch.randn(3, 3, 512, 1024).cuda()
    net = deeplabv3(False).cuda()
    result = net(x)
    for k, v in result.items():
        print(k, v.shape)
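IntermediateLayerGetter is the torchvision utility doing the real work in deeplabv3(): it wraps a model and returns an OrderedDict of the requested intermediate outputs under new names. A standalone sketch with a plain resnet50 (shapes are for a 256x512 input; the demo backbone above comes out at different strides because of its modified stem and dilation):

import torch
from torchvision.models import resnet50
from torchvision.models._utils import IntermediateLayerGetter

backbone = IntermediateLayerGetter(
    resnet50(), return_layers={'layer2': 'res2', 'layer4': 'out'})
feats = backbone(torch.randn(1, 3, 256, 512))
for name, f in feats.items():
    print(name, f.shape)
# res2 torch.Size([1, 512, 32, 64])
# out torch.Size([1, 2048, 8, 16])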

2. Put this code in sampling_points.py.

import torch
import torch.nn.functional as F


def point_sample(input, point_coords, **kwargs):
    """
    From Detectron2, point_features.py#19

    A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords
    tensors. Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to
    lie inside the [0, 1] x [0, 1] square.

    Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains a feature map on a H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
            [0, 1] x [0, 1] normalized point coordinates.

    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
            features for points in `point_coords`. The features are obtained via bilinear
            interpolation from `input`, the same way as :function:`torch.nn.functional.grid_sample`.
    """
    add_dim = False
    if point_coords.dim() == 3:
        add_dim = True
        point_coords = point_coords.unsqueeze(2)
    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
    if add_dim:
        output = output.squeeze(3)
    return output


@torch.no_grad()
def sampling_points(mask, N, k=3, beta=0.75, training=True):
    """
    Follows 3.1. Point Selection for Inference and Training.

    In training: `The sampling strategy selects N points on a feature map to train on.`
    In inference: `then selects the N most uncertain points`

    Args:
        mask (Tensor): [B, C, H, W]
        N (int): `During training we sample as many points as there are on a stride 16 feature map of the input`
        k (int): over-generation multiplier
        beta (float): ratio of importance points
        training (bool): flag

    Return:
        selected points (Tensor): [B, num_points, 2] normalized coordinates
        (in inference mode, the flattened indices of the points are returned as well)
    """
    assert mask.dim() == 4, "Dim must be N(Batch)CHW"
    device = mask.device
    B, _, H, W = mask.shape
    mask, _ = mask.sort(1, descending=True)  # sort class scores so channels 0 and 1 hold the top-2

    if not training:
        H_step, W_step = 1 / H, 1 / W
        N = min(H * W, N)
        uncertainty_map = -1 * (mask[:, 0] - mask[:, 1])  # small top-2 gap => high uncertainty
        _, idx = uncertainty_map.view(B, -1).topk(N, dim=1)

        # convert flat indices back to normalized (x, y) coordinates at pixel centers
        points = torch.zeros(B, N, 2, dtype=torch.float, device=device)
        points[:, :, 0] = W_step / 2.0 + (idx % W).to(torch.float) * W_step
        points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step
        return idx, points

    # Official comment, point_features.py#92:
    # It is crucial to calculate uncertainty based on the sampled prediction values of the points.
    # Calculating uncertainties of the coarse predictions first and then sampling them at the points
    # leads to worse results. To illustrate the difference: a sampled point between two coarse
    # predictions with -1 and 1 logits has a 0 logit prediction and therefore 0 uncertainty, but if
    # one calculates uncertainties for the coarse predictions first (-1 and -1) and samples them at
    # the center point, one gets -1 uncertainty.
    over_generation = torch.rand(B, k * N, 2, device=device)      # k*N random candidate points
    over_generation_map = point_sample(mask, over_generation, align_corners=False)

    uncertainty_map = -1 * (over_generation_map[:, 0] - over_generation_map[:, 1])
    _, idx = uncertainty_map.topk(int(beta * N), -1)              # keep the beta*N most uncertain

    shift = (k * N) * torch.arange(B, dtype=torch.long, device=device)
    idx += shift[:, None]                                         # offset indices per batch element

    importance = over_generation.view(-1, 2)[idx.view(-1), :].view(B, int(beta * N), 2)
    coverage = torch.rand(B, N - int(beta * N), 2, device=device)  # remaining points: uniform coverage
    return torch.cat([importance, coverage], 1).to(device)
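A quick usage check of the two functions in inference mode (random logits, shapes only):

import torch
from sampling_points import sampling_points, point_sample

mask = torch.randn(1, 21, 32, 64)                      # coarse logits [B, C, H, W]
idx, points = sampling_points(mask, 48, training=False)
print(idx.shape, points.shape)   # torch.Size([1, 48]) torch.Size([1, 48, 2])

feats = point_sample(mask, points, align_corners=False)
print(feats.shape)               # torch.Size([1, 21, 48])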

3. Put this code in pointrend.py.


import torch
import torch.nn as nn
import torch.nn.functional as F
from sampling_points import sampling_points, point_sample


class PointHead(nn.Module):
    def __init__(self, in_c=533, num_classes=21, k=3, beta=0.75):
        # in_c = 533 = 512 (res2 channels) + 21 (coarse class logits)
        super().__init__()
        self.mlp = nn.Conv1d(in_c, num_classes, 1)
        self.k = k
        self.beta = beta

    def forward(self, x, res2, out):
        """
        1. Fine-grained features are interpolated from res2 for DeeplabV3
        2. During training we sample as many points as there are on a stride 16 feature map of the input
        3. To measure prediction uncertainty we use the same strategy during training and inference:
           the difference between the most confident and second most confident class probabilities.
        """
        self.training = False  # forced off here so this demo always runs the inference path
        if not self.training:
            return self.inference(x, res2, out)

        points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)
        # print("points", points.shape)  # [3, 32, 2]: 32 points per image

        coarse = point_sample(out, points, align_corners=False)
        fine = point_sample(res2, points, align_corners=False)

        feature_representation = torch.cat([coarse, fine], dim=1)
        print("feature_representation = ", feature_representation.shape)

        rend = self.mlp(feature_representation)  # input: 533 channels per point, output: 21

        return {"rend": rend, "points": points}

    @torch.no_grad()
    def inference(self, x, res2, out):
        """
        During inference, subdivision uses N=8096
        (i.e., the number of points in the stride 16 map of a 1024x2048 image)
        """
        num_points = 8096
        print("x = ", x.shape)
        print("res2 = ", res2.shape)
        while out.shape[-1] != x.shape[-1]:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)
            print("out old = ", out.shape)

            points_idx, points = sampling_points(out, num_points, training=self.training)

            coarse = point_sample(out, points, align_corners=False)
            fine = point_sample(res2, points, align_corners=False)

            feature_representation = torch.cat([coarse, fine], dim=1)

            rend = self.mlp(feature_representation)

            B, C, H, W = out.shape
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
            out = (out.reshape(B, C, -1)
                      .scatter_(2, points_idx, rend)
                      .view(B, C, H, W))
            print("out new = ", out.shape)

        return {"fine": out}


class PointRend(nn.Module):
    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        result = self.backbone(x)
        print("x = ", x.shape)
        result.update(self.head(x, result["res2"], result["coarse"]))
        return result


if __name__ == "__main__":
    x = torch.randn(3, 3, 256, 512)
    from deeplab import deeplabv3

    net = PointRend(deeplabv3(False), PointHead())
    out = net(x)
    for k, v in out.items():
        print("=========")
        print(k, v.shape)
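The three files only cover the forward pass. For completeness, here is a sketch of how the point-wise training loss could be computed: sample the ground-truth labels at the same point coordinates, then apply cross-entropy to the rend logits. This is my own sketch, not part of the original code; the function name and the ignore_index value are illustrative assumptions.

import torch
import torch.nn.functional as F
from sampling_points import point_sample

def point_loss(rend, points, gt):
    # rend: [B, num_classes, P] point logits; points: [B, P, 2]; gt: [B, H, W] integer labels
    # 'nearest' keeps the labels hard instead of blending neighboring classes
    gt_points = point_sample(gt.float().unsqueeze(1), points,
                             mode='nearest', align_corners=False)
    return F.cross_entropy(rend, gt_points.squeeze(1).long(), ignore_index=255)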
