本文主要是介绍PointNet++改进策略 :模块改进 | OE Unit | PointSIFT,结合方向信息提升模型精度,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
- 论文:PointSIFT: A SIFT-like Network Module for 3D Point Cloud Semantic Segmentation
- 来源:ECCV 2020
- 机构:清华大学 & 上海交通大学
- 论文:http://arxiv.org/abs/1807.00652
- 代码:
- https://github.com/MVIG-SJTU/pointSIFT/
- https://github.com/lelouedec/3DNetworksPytorch
- 灵感来源:来着SIFT方法,该方法是提取2D图像中图像特征的方法,来生成类似图像中形状的描述
- 方法实现:设计一个网络PointSIFT,能够提取物体中方向的信息同时能适应不同尺度的物体
- 实验结果:在ScanNet数据集上语义分割任务中,相比PointNet++网络IOU提升2.3%,Acc提升3.22%
网络整体
整体网络如下,PointSIFT对PointNet++中的MLP层进行了改进。
PointSIFT模块
通过堆叠多个 OE 单元,使得网络能够适应不同尺度的形状特征。在 PointSIFT 模块中堆叠多个 OE 单元,每个单元的感受野逐渐增大,以捕捉更大范围内的局部特征。使用shortcuts将各级 OE 单元的输出连接起来,然后通过一个点卷积(point-wise convolution)将多尺度特征融合,最终输出具有多尺度感知能力的特征表示。
Orientation-Encoding Unit
OE单元是一个方向编码单元来描述八个关键方向。
- 其中主要两个步骤:
- 八邻域搜索(Stacked 8-Neighborhood Search)对于每个输入点,PointSIFT 首先按照坐标轴将空间划分为八个象限,并在每个象限中找到距离最近的点作为邻居。如果某个象限内没有点,则将输入点自身复制为其最近邻。
方向编码卷积(Orientation-Encoding Convolution, OEC):在八邻域点上进行三阶段的卷积操作(沿 X, Y, Z 轴),将这些点的特征进行融合。具体来说,卷积操作将 2×2×2 立方体中的特征依次沿各轴进行卷积,最终输出包含方向信息的特征表示。
- 八邻域搜索(Stacked 8-Neighborhood Search)对于每个输入点,PointSIFT 首先按照坐标轴将空间划分为八个象限,并在每个象限中找到距离最近的点作为邻居。如果某个象限内没有点,则将输入点自身复制为其最近邻。
代码实现
我在文章顶端中,有俩个PointSIFT实现版本分别是pytorch和tensorflow实现,大家根据需求迁移到自己的项目中,下面我讲以pytorch版本网络设计对代码进行注释,具体的实现细节可以参考上面的链接中
class PointSIFT(nn.Module):def __init__(self, nb_classes):super(PointSIFT, self).__init__()self.num_classes = nb_classes# 第一个 PointSIFT 残差模块,用于提取局部特征,半径为 0.1,输出通道为 64。self.pointsift_res_m3 = PointSIFT_res_module(radius=0.1, output_channel=64, merge='concat')# 第一个 PointNet 下采样模块,采样 1024 个点,半径为 0.1,32 个邻居点,输出特征维度为 128。self.pointnet_sa_m3 = Pointnet_SA_module(npoint=1024, radius=0.1, nsample=32, in_channel=64, mlp=[64, 128], group_all=False)# 第二个 PointSIFT 残差模块,半径为 0.2,输出通道为 128。self.pointsift_res_m4 = PointSIFT_res_module(radius=0.2, output_channel=128, extra_input_channel=128)# 第二个 PointNet 下采样模块,采样 256 个点,半径为 0.2,32 个邻居点,输出特征维度为 256。self.pointnet_sa_m4 = Pointnet_SA_module(npoint=256, radius=0.2, nsample=32, in_channel=128, mlp=[128, 256], group_all=False)# 第三个 PointSIFT 残差模块,第一个子模块,半径为 0.2,输出通道为 256。self.pointsift_res_m5_1
这篇关于PointNet++改进策略 :模块改进 | OE Unit | PointSIFT,结合方向信息提升模型精度的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!