【论文笔记】图像边缘精细分割 PointRend: Image Segmentation as Rendering

本文主要是介绍【论文笔记】图像边缘精细分割 PointRend: Image Segmentation as Rendering,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 1 综述
  • 2 PointRend Mudule解析
    • 2.1 点选择策略
    • 2.2 点的特征提取
    • 2.3 点的分类预测
  • 3 源代码解析
    • 3.1 Points Selection
    • 3.2 Point-wise Representation and Point Head
    • 3.3 Loss Function
  • 4 实验结果
  • 4 参考文献

1 综述

今天分享一篇何凯明2020年的论文《PointRend: Image Segmentation as Rendering》,文章主要解决的问题就是在图像分割任务中边缘不够精细的问题。

因为模型最容易误判的 pixel 基本上都在物体边缘, 边缘只占了整个物体中非常小的一部分。所以基于这样的一个想法,作者提出可以每次在预测出来的 mask 中只选择 Top N 最不确定的点位置进行单独预测,其他部分的像素点采用直接插值方法,这样就既可以解决了精度问题,还保证了内存与计算量尽可能的小。

论文地址:《PointRend: Image Segmentation as Rendering》

2 PointRend Mudule解析

PointRend 模块包含三个主要组件:

1、Point Selection Strategy:选择少量真值点执行预测,避免对高分辨率输出网格中的所有像素进行过度计算;

2、Point-wise feature Representation:使用每个选中点在 f 规则网格上的 4 个最近邻点,利用双线性内插计算真值点的特征。因此,该方法能够利用 f 的通道维度中编码的子像素信息,来预测比 f 规则网格分辨率高的分割预测;

3、Point Head:一个小型神经网络,用于基于逐点特征表示预测标签,它独立于每个点。每个细分点的特征可以通过 Bilinear 插值得到,每个位置上的 classifier 通过一个简单的MLP来实现。这其实是等价于用一个1*1 的 conv 来预测,但是对于中心很确定的点并不计算;

2.1 点选择策略

训练期间 PointRend 需要选择训练点来构造 point-wise features,以训练point head。原则上,点的选择策略可以类似于推理中使用的细分策略。但是, subdivision 引入了循环迭代,这对使用反向传播训练的神经网络不太友好。因此训练阶段使用了基于随机采样的非迭代策略;

推理阶段

在每次迭代中,PointRend使用双线性插值对之前预测的分割 Mask 进行上采样,然后在这个密度更大的网格上选择N个最不确定的点(例如,对于二值预测,概率接近0.5的点)。然后,PointRend为这N个点中的每一个点计算点特征表示,并预测它们的标签。这个过程是重复的,直到分割是上采样到所需的分辨率。一个coarse-to-fine的过程;
在这里插入图片描述

训练阶段

采用随机采样的非迭代策略来进行,具体如下:

(1):我们通过从均匀分布中随机抽样 kN 点(k>1)来过度生成候选点;

(2):从 kN 个点中选取βN(β ∈[0,1])个最不确定的点。使用0.5与概率之间的距离作为逐点不确定性度量,概率指的是对真实值的粗略预测概率。原文:(We use the distance between 0.5 and the probability of the ground truth class interpolated from the coarse prediction as the point-wise uncertainty measure.);

(3):在从均匀分布中选取 (1 - β)N 个点;
在这里插入图片描述

2.2 点的特征提取

PointRend 通过组合低层特征 (fine-grained features) 和高层特征 (coarse prediction),在选定的点上构造逐点特征。

Fine-grained features

为了让PointRend呈现出精细的分割细节,研究人员为CNN特征图中的每个采样点提取了特征向量。细粒度特征虽然可以解析细节,但也存在两方面的不足:
(1)不包含特定区域信息,对于实例分割任务,就可能在同一点上预测出不同的标签。
(2)用于细粒度特征的特征映射,可能仅包含相对较低级别的信息。

Coarse prediction features

来自于现有网络架构的输出,提供更多全局背景,用于对 fine-grained features 进行补充。以实例分割为例,coarse prediction可以是Mask R-CNN中 7×7 轻量级mask head的输出。
在这里插入图片描述

2.3 点的分类预测

通过一个多层感知机(MLP)来对每个被选中的点进行分类预测,所有点共享MLP的权重,MLP可以通过标准的任务特定的分段损失来训练。

3 源代码解析

3.1 Points Selection

def sampling_points(mask, N, k=3, beta=0.75, training=True):"""主要思想:根据粗糙的预测结果,找出不确定的像素点:param mask: 粗糙的预测结果(out)   eg.[2, 19, 48, 48]:param N: 不确定点个数(train:N = 图片的尺寸/16, test: N = 8096)    eg. N=48:param k: 超参:param beta: 超参:param training::return: 不确定点的位置坐标  eg.[2, 48, 2]"""assert mask.dim() == 4, "Dim must be N(Batch)CHW"   #this mask is out(coarse)device = mask.deviceB, _, H, W = mask.shape   #first: mask[1, 19, 48, 48]mask, _ = mask.sort(1, descending=True) #_ : [1, 19, 48, 48],按照每一类的总体得分排序if not training:H_step, W_step = 1 / H, 1 / WN = min(H * W, N)uncertainty_map = -1 * (mask[:, 0] - mask[:, 1])#mask[:, 0]表示每个像素最有可能的分类,mask[:, 1]表示每个像素次有可能的分类,当一个像素#即是最有可能的又是次有可能的,则证明它不好预测,对应的uncertainty_map就相对较大_, idx = uncertainty_map.view(B, -1).topk(N, dim=1) #id选出最不好预测的N个点points = torch.zeros(B, N, 2, dtype=torch.float, device=device)points[:, :, 0] = W_step / 2.0 + (idx  % W).to(torch.float) * W_step    #点的横坐标points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step    #点的纵坐标return idx, points  #idx:48 || points:[1, 48, 2]

3.2 Point-wise Representation and Point Head

挑选出的不确定点所在图片的相对位置坐标来找到对应的特征点,将此点对应的特征向量与此点的粗糙预测结果合并,然后通过一个MLP进行细分预测,代码如下:

##训练阶段
def forward(self, x, res2, out):"""主要思路:通过 out(粗糙预测)计算出top N 个不稳定的像素点,针对每个不稳定像素点得到在res2(fine)和out(coarse)中对应的特征,组合N个不稳定像素点对应的fine和coarse得到rend,再通过mlp得到更准确的预测;:param x: 表示输入图片的特征     eg.[2, 3, 768, 768]:param res2: 表示xception的第一层特征输出     eg.[2, 256, 192, 192](下采样4倍):param out: 表示经过级联空洞卷积提取的特征的粗糙预测    eg.[2, 19, 48, 48](下采样16倍):return: rend:更准确的预测,points:不确定像素点的位置""""""1. Fine-grained features are interpolated from res2 for DeeplabV32. During training we sample as many points as there are on a stride 16 feature map of the input3. To measure prediction uncertaintywe use the same strategy during training and inference: the difference between the mostconfident and second most confident class probabilities."""if not self.training:return self.inference(x, res2, out)#获得不确定点的坐标points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta) #out:[2, 19, 48, 48] || x:[2, 3, 768, 768] || points:[2, 48, 2]#根据不确定点的坐标,得到对应的coarse feature;coarse = point_sample(out, points, align_corners=False) #[2, 19, 48]#根据不确定点的坐标,得到对应的fine feature;fine = point_sample(res2, points, align_corners=False)  #[2, 256, 48]#将对应的特征向量合并;feature_representation = torch.cat([coarse, fine], dim=1)   #[2, 275, 48]#使用MLP进行细分预测;rend = self.mlp(feature_representation) #[2, 19, 48]return {"rend": rend, "points": points}##推理阶段
@torch.no_grad()def inference(self, x, res2, out):"""输入:x:[1, 3, 768, 768],表示输入图片的特征res2:[1, 256, 192, 192],表示xception的第一层特征输出(下采样4倍)out:[1, 19, 48, 48],表示经过级联空洞卷积提取的特征的粗糙预测(下采样16倍)输出:out:[1,19,768,768],表示最终图片的预测主要思路:通过 out计算出top N = 8096 个不稳定的像素点,针对每个不稳定像素点得到在res2(fine)和out(coarse)中对应的特征,组合8096个不稳定像素点对应的fine和coarse得到rend,再通过mlp得到更准确的预测,迭代至rend的尺寸大小等于输入图片的尺寸大小""""""During inference, subdivision uses N=8096(i.e., the number of points in the stride 16 map of a 1024×2048 image)"""num_points = 8096while out.shape[-1] != x.shape[-1]: #out:[1, 19, 48, 48], x:[1, 3, 768, 768]#每一次预测均会扩大2倍像素,直至与原图像素大小一致out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)   #out[1, 19, 48, 48]points_idx, points = sampling_points(out, num_points, training=self.training)   #points_idx:8096 || points:[1, 8096, 2]coarse = point_sample(out, points, align_corners=False) #coarse:[1, 19, 8096]   表示8096个不稳定像素点根据高级特征得出的对应的类别fine = point_sample(res2, points, align_corners=False)  #fine:[1, 256, 8096]    表示8096个不稳定像素点根据低级特征得出的对应类别feature_representation = torch.cat([coarse, fine], dim=1)   #[1, 275, 8096] 表示8096个不稳定像素点合并fine和coarse的特征rend = self.mlp(feature_representation) #[1, 19, 8096]B, C, H, W = out.shape  #first:[1, 19, 128, 256]points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)  #[1, 19, 8096]out = (out.reshape(B, C, -1).scatter_(2, points_idx, rend)    #[1, 19, 32768].view(B, C, H, W))    #[1, 19, 128, 256]return {"fine": out}import torch.nn.functional as F
def point_sample(input, point_coords, **kwargs):"""A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside[0, 1] x [0, 1] square.Args:input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains[0, 1] x [0, 1] normalized point coordinates.Returns:output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that containsfeatures for points in `point_coords`. The features are obtained via bilinearinterplation from `input` the same way as :function:`torch.nn.functional.grid_sample`."""add_dim = Falseif point_coords.dim() == 3:add_dim = Truepoint_coords = point_coords.unsqueeze(2)output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)if add_dim:output = output.squeeze(3)return output

关于 torch.nn.functional.grid_sample 的说明,可点击查看!

3.3 Loss Function

由于有整体预测及细分点预测两部分,所以Loss也由这两部分加和而成,代码如下:

class PointRendLoss(nn.CrossEntropyLoss):def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):super(PointRendLoss, self).__init__(ignore_index=ignore_index)self.aux = auxself.aux_weight = aux_weightself.ignore_index = ignore_indexdef forward(self, *inputs, **kwargs):result, gt = tuple(inputs)#result['res2']: [2, 256, 192, 192], 即xception的c1层提取到的特征#result['coarse']: [2, 19, 48, 48]#result['rend']: [2, 19, 48]#result['points']:[2, 48, 2]#gt:[2, 768, 768], 即图片对应的label#pred:[2, 19, 768, 768],将粗糙预测的插值到label大小pred = F.interpolate(result["coarse"], gt.shape[-2:], mode="bilinear", align_corners=True)#整体像素点的交叉熵lossseg_loss = F.cross_entropy(pred, gt, ignore_index=self.ignore_index)#根据不确定点坐标获得不确定点对应的gtgt_points = point_sample(gt.float().unsqueeze(1),result["points"],mode="nearest",align_corners=False).squeeze_(1).long()#不确定点的交叉熵losspoints_loss = F.cross_entropy(result["rend"], gt_points, ignore_index=self.ignore_index)#整体+不确定点loss = seg_loss + points_lossreturn dict(loss=loss)

4 实验结果

在各种定量的评测中,PointRend 均能提升1~2点的mask AP,而且展现出越强的backbone,越好的标注提升越高的特点。

其在 Cityscapes 样本上的实例分割和语义分割结果对比如下图
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

pointRend在语义分割中inference细节如下图:
在这里插入图片描述

4 参考文献

1、语义分割之PointRend论文与源码解读
2、PointRend
3、Ross、何恺明等人提出PointRend:渲染思路做图像分割,显著提升Mask R-CNN性能

这篇关于【论文笔记】图像边缘精细分割 PointRend: Image Segmentation as Rendering的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/294277

相关文章

使用Python实现批量分割PDF文件

《使用Python实现批量分割PDF文件》这篇文章主要为大家详细介绍了如何使用Python进行批量分割PDF文件功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、架构设计二、代码实现三、批量分割PDF文件四、总结本文将介绍如何使用python进js行批量分割PDF文件的方法

基于WinForm+Halcon实现图像缩放与交互功能

《基于WinForm+Halcon实现图像缩放与交互功能》本文主要讲述在WinForm中结合Halcon实现图像缩放、平移及实时显示灰度值等交互功能,包括初始化窗口的不同方式,以及通过特定事件添加相应... 目录前言初始化窗口添加图像缩放功能添加图像平移功能添加实时显示灰度值功能示例代码总结最后前言本文将

使用Python将长图片分割为若干张小图片

《使用Python将长图片分割为若干张小图片》这篇文章主要为大家详细介绍了如何使用Python将长图片分割为若干张小图片,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. python需求的任务2. Python代码的实现3. 代码修改的位置4. 运行结果1. Python需求

C#中字符串分割的多种方式

《C#中字符串分割的多种方式》在C#编程语言中,字符串处理是日常开发中不可或缺的一部分,字符串分割是处理文本数据时常用的操作,它允许我们将一个长字符串分解成多个子字符串,本文给大家介绍了C#中字符串分... 目录1. 使用 string.Split2. 使用正则表达式 (Regex.Split)3. 使用

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

论文翻译:arxiv-2024 Benchmark Data Contamination of Large Language Models: A Survey

Benchmark Data Contamination of Large Language Models: A Survey https://arxiv.org/abs/2406.04244 大规模语言模型的基准数据污染:一项综述 文章目录 大规模语言模型的基准数据污染:一项综述摘要1 引言 摘要 大规模语言模型(LLMs),如GPT-4、Claude-3和Gemini的快

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear