PointNet code study (PyTorch version)

2024-02-28 15:58


Source code

pointnet.pytorch
Many thanks to the author!

Code structure

(pytorch) s@s:~/pointnet.pytorch$ tree -d
.
├── misc
├── pointnet
│   └── __pycache__
├── scripts
├── shapenetcore_partanno_segmentation_benchmark_v0
│   ├── 02691156
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 02773838
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 02954340
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 02958343
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03001627
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03261776
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03467517
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03624134
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03636649
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03642806
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03790512
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03797390
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 03948459
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 04099429
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 04225987
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   ├── 04379243
│   │   ├── points
│   │   ├── points_label
│   │   └── seg_img
│   └── train_test_split
└── utils
    ├── cls
    ├── __pycache__
    └── seg

74 directories
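Each category folder above is named by its synset id and pairs points/<token>.pts with points_label/<token>.seg. dataset.py reads both files with np.loadtxt and later subtracts 1 from the labels. From that, a .pts file appears to hold one "x y z" row per point, and a .seg file one 1-based part label per point. The sketch below reads such a pair under that assumption. load_part_sample is a hypothetical helper, and the two in-memory files stand in for real dataset files.

```python
import io

import numpy as np


def load_part_sample(pts_file, seg_file):
    """Load one part-segmentation sample: (N, 3) float32 points and (N,) int64 labels.

    Mirrors the two np.loadtxt calls in ShapeNetDataset.__getitem__;
    accepts file paths or file-like objects.
    """
    points = np.loadtxt(pts_file).astype(np.float32)  # one "x y z" row per point
    labels = np.loadtxt(seg_file).astype(np.int64)    # one 1-based part label per point
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError("expected an (N, 3) point array, got %r" % (points.shape,))
    if len(points) != len(labels):
        raise ValueError("points and labels differ in length")
    return points, labels


# Tiny stand-ins for points/<token>.pts and points_label/<token>.seg
pts_text = "0.1 0.2 0.3\n-0.1 0.0 0.5\n0.2 -0.3 0.1\n"
seg_text = "1\n2\n2\n"
points, labels = load_part_sample(io.StringIO(pts_text), io.StringIO(seg_text))
print(points.shape, labels.tolist())  # (3, 3) [1, 2, 2]
```

The training script turns these labels into 0-based class indices with target - 1 before computing the loss.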

utils

├── cls
│   ├── cls_model_0.pth
│   └── cls_model_1.pth
├── point_test.pts
├── __pycache__
│   ├── show3d_balls.cpython-36.pyc
│   ├── show3d_balls.cpython-37.pyc
│   └── show_seg.cpython-36.pyc
├── render_balls_so.cpp
├── render_balls_so.so
├── seg
│   ├── seg_model_Chair_0.pth
│   ├── seg_model_Chair_1.pth
│   ├── seg_model_Chair_2.pth
│   ├── seg_model_Chair_3.pth
│   └── seg_model_Chair_4.pth
├── show3d_balls.py
├── show_cls.py
├── show_points.py
├── show_seg.py
├── train_classification.py
└── train_segmentation.py

3 directories, 19 files
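Overview
  • The cls and seg folders hold trained model weights (.pth);
  • train_classification.py and train_segmentation.py are the training scripts;
  • show_seg.py and show_cls.py load a trained model and test it; show_seg.py also visualizes the segmentation result;
  • show3d_balls.py is the visualization script and contains the rendering helpers;
  • show_points.py is a test script I wrote myself (can be ignored);
  • render_balls_so.cpp and render_balls_so.so are the rendering library used for visualization;
  • point_test.pts is a point cloud file for my own tests (can be ignored).
Detailed code annotations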
  • show_seg.py
'''
Segment a point cloud from the test set and visualize the result.
Example: python show_seg.py --model seg/seg_model_Chair_1.pth --class_choice Airplane --idx 2
'''
from __future__ import print_function
from show3d_balls import showpoints
import argparse
import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable

# Put the pointnet.pytorch root on sys.path first, so that the pointnet package can be found
import sys
sys.path.append('/home/s/pointnet.pytorch')

from pointnet.dataset import ShapeNetDataset
from pointnet.model import PointNetDenseCls
import matplotlib.pyplot as plt

# Command-line arguments
parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--idx', type=int, default=0, help='model index')
parser.add_argument('--dataset', type=str, default='', help='dataset path')
parser.add_argument('--class_choice', type=str, default='', help='class choice')

# Printing opt gives a line such as:
#   Namespace(class_choice='Airplane', dataset='',
#             idx=2, model='seg/seg_model_Chair_1.pth')
opt = parser.parse_args()
# print("opt:{}".format(opt))

# Preprocess the data: the test-split shapes of one category
d = ShapeNetDataset(
    # root=opt.dataset,
    root='/home/s/pointnet.pytorch/shapenetcore_partanno_segmentation_benchmark_v0',
    class_choice=[opt.class_choice],  # which category to load
    split='test',
    data_augmentation=False)

idx = opt.idx

# print("d:{}".format(d))
print("model %d/%d" % (idx, len(d)))  # len(d) is the number of test shapes in the category (e.g. airplanes)
# model 2/341

# print('dir(d):{}'.format(dir(d)))
# print('d[idx]:{}'.format(d[idx]))
point, seg = d[idx]  # the idx-th point cloud of the category
print(point.size(), seg.size())  # seg holds the part label of every point
# torch.Size([2500, 3]) torch.Size([2500])

point_np = point.numpy()  # torch tensor -> numpy array

# Visualization: a 10-color table; .seg labels start at 1, hence seg - 1
cmap = plt.cm.get_cmap("hsv", 10)
cmap = np.array([cmap(i) for i in range(10)])[:, :3]
gt = cmap[seg.numpy() - 1, :]

# Load the model; the number of parts k is read from the shape of the last layer
state_dict = torch.load(opt.model)
classifier = PointNetDenseCls(k=state_dict['conv4.weight'].size()[0])
classifier.load_state_dict(state_dict)
classifier.eval()  # evaluation mode

# Transpose the point cloud from (2500, 3) to (3, 2500)
point = point.transpose(1, 0).contiguous()
print('point.transpose(1, 0).shape: ', point.shape)

point = Variable(point.view(1, point.size()[0], point.size()[1]))  # add a batch dimension: (1, 3, 2500)
print('--------------------')
print(point.dtype)
pred, _, _ = classifier(point)  # segmentation
# print('\npred.shape:',pred[0],'\n')
pred_choice = pred.data.max(2)[1]
print(pred_choice.numpy())  # predicted part of every point
# print(pred_choice.size())
print(pred_choice.numpy()[0])  # [1 1 1 ... 1 1 1]
pred_color = cmap[pred_choice.numpy()[0], :]  # color each point by its predicted part
print('\npred_color: ', pred_color.shape, '\n')

showpoints(point_np, gt, pred_color)  # pred_color is a (2500, 3) matrix
print(point_np.shape)
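The last steps of show_seg.py pick the most likely part for each point and then look up a color for it. Those steps can be reproduced with NumPy alone. torch's max(dim) returns (values, indices), so pred.data.max(2)[1] is the per-point argmax. Because log is monotonic, the argmax of log_softmax is the same as the argmax of softmax. In this sketch, labels_to_colors is a hypothetical helper, and a hand-written 3-color palette replaces the matplotlib hsv table.

```python
import numpy as np


def labels_to_colors(log_probs, palette):
    """Map per-point log-probabilities of shape (1, N, k) to (N, 3) RGB colors."""
    pred_choice = np.argmax(log_probs, axis=2)  # (1, N): index of the most likely part
    return palette[pred_choice[0], :]           # (N, 3): one color per point


# Hypothetical palette standing in for the "hsv" color table
palette = np.array([[1.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0],
                    [0.0, 0.0, 1.0]])
# One cloud of 3 points with k = 3 parts, in log space like the model output
log_probs = np.log(np.array([[[0.7, 0.2, 0.1],
                              [0.1, 0.1, 0.8],
                              [0.2, 0.6, 0.2]]]))
colors = labels_to_colors(log_probs, palette)
print(colors.tolist())  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]

# Ground-truth labels are 1-based, hence the "- 1" used for gt in show_seg.py
gt_labels = np.array([1, 3, 2])
gt_colors = palette[gt_labels - 1, :]
```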
  • show_cls.py
from __future__ import print_function  # use the Python 3 print function
import argparse
import torch
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable

import sys
sys.path.append('/home/s/pointnet.pytorch')

from pointnet.dataset import ShapeNetDataset  # ShapeNetDataset is a class; it is instantiated below
from pointnet.model import PointNetCls  # classification network
import torch.nn.functional as F

# showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500)))

parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--num_points', type=int, default=2500, help='number of points per cloud')

opt = parser.parse_args()
print(opt)

# Instantiate ShapeNetDataset
test_dataset = ShapeNetDataset(
    # root='shapenetcore_partanno_segmentation_benchmark_v0',
    root='/home/s/pointnet.pytorch/shapenetcore_partanno_segmentation_benchmark_v0',
    split='test',
    classification=True,
    npoints=opt.num_points,
    data_augmentation=False)

# Test data loader
testdataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=True)

# Load the model
classifier = PointNetCls(k=len(test_dataset.classes))
classifier.cuda()
classifier.load_state_dict(torch.load(opt.model))
classifier.eval()

for i, data in enumerate(testdataloader, 0):
    points, target = data
    points, target = Variable(points), Variable(target[:, 0])
    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    pred, _, _ = classifier(points)  # classify
    loss = F.nll_loss(pred, target)  # loss

    # Accuracy: divide by the actual batch size, since the last batch can hold fewer than 32 shapes
    pred_choice = pred.data.max(1)[1]
    correct = pred_choice.eq(target.data).cpu().sum()
    print('i:%d  loss: %f accuracy: %f' % (i, loss.data.item(), correct.item() / float(target.size(0))))
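show_cls.py prints accuracy per batch. A per-batch figure divided by a fixed 32 also under-reports the final batch whenever the test set size is not a multiple of 32. A single accuracy over the whole test set is often more useful. The sketch below accumulates correct predictions and sample counts across batches. accumulate_accuracy and the toy batches are illustrative assumptions, with NumPy arrays standing in for the model's (B, k) log-probabilities.

```python
import numpy as np


def accumulate_accuracy(batches):
    """Overall accuracy over (log_probs, target) batches of possibly different sizes.

    log_probs: (B, k) per-shape log-probabilities; target: (B,) class indices.
    """
    correct, total = 0, 0
    for log_probs, target in batches:
        pred_choice = np.argmax(log_probs, axis=1)  # same role as pred.data.max(1)[1]
        correct += int(np.sum(pred_choice == target))
        total += len(target)                        # real batch size, not a constant
    return correct / float(total)


# Two hypothetical batches: 3 shapes, then a smaller final batch of 1
b1 = (np.log(np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])), np.array([0, 1, 1]))
b2 = (np.log(np.array([[0.3, 0.7]])), np.array([1]))
print(accumulate_accuracy([b1, b2]))  # 0.75
```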
  • show_points.py
'''
Written by me for testing:
visualize a point cloud file from this folder and segment it.
Input: an n*3 matrix
'''
from __future__ import print_function
from show3d_balls import showpoints
import argparse
import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable

# The repository root must be on sys.path before the pointnet package is imported
import sys
sys.path.append('/home/s/pointnet.pytorch')
from pointnet.model import PointNetDenseCls

# with open('./point_test.pts') as file:
#     for line in file:
#         print(len(line))

import matplotlib.pylab as plt
# from utils.show_seg import seg

# points=np.loadtxt('./point_test.pts')
points = np.loadtxt('./point_test.pts', dtype=np.float32)  # the model only accepts float32 input
print(points.shape)

# Color table for visualization
cmap = plt.cm.get_cmap("hsv", 10)
cmap = np.array([cmap(i) for i in range(10)])[:, :3]
# gt = cmap[seg.numpy() - 1, :]

# Show the raw point cloud
showpoints(points)

# Resample to 2500 points (unlike ShapeNetDataset, no centering or scaling is applied here)
choice = np.random.choice(len(points), 2500, replace=True)
# print('choice:{}'.format(choice))
points = points[choice, :]
print('points[choice, :]:{}'.format(points))
point_np = points

# Load the model
state_dict = torch.load('./seg/seg_model_Chair_1.pth')
classifier = PointNetDenseCls(k=state_dict['conv4.weight'].size()[0])
classifier.load_state_dict(state_dict)
classifier.eval()  # evaluation mode

# Transpose the point cloud
points = torch.from_numpy(points)
print(points.shape)
point = points.transpose(1, 0).contiguous()
print('point.transpose(1, 0).shape: ', point.shape)

point = Variable(point.view(1, point.size()[0], point.size()[1]))  # torch Variable of shape (1, 3, 2500)
print('--------------------')
print(point.dtype)
# point=torch.tensor(point,dtype=torch.float32)
pred, _, _ = classifier(point)  # segmentation
print(pred)
pred_choice = pred.data.max(2)[1]
print(pred_choice.numpy())  # predicted part of every point
# print(pred_choice.size())
print(pred_choice.numpy()[0])  # [1 1 1 ... 1 1 1]
pred_color = cmap[pred_choice.numpy()[0], :]  # color each point by its predicted part
print('\npred_color: ', pred_color.shape, '\n')
print(pred_color.dtype)
# point_np=point.numpy().reshape(2500,3)
print(point_np.shape)
print(point_np.dtype)
showpoints(point_np, pred_color, pred_color)  # pred_color is a (2500, 3) matrix
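show_points.py feeds the raw coordinates of point_test.pts to the network. ShapeNetDataset.__getitem__, by contrast, centers every cloud and scales it into the unit sphere, so predictions on un-normalized input may be unreliable. Below is a minimal NumPy sketch of the same preprocessing without augmentation. prepare_cloud is a hypothetical helper, and a random array stands in for the .pts file. The (1, 3, N) result could then be wrapped with torch.from_numpy before being passed to the classifier.

```python
import numpy as np


def prepare_cloud(points, npoints=2500, rng=None):
    """Resample, center and scale an (n, 3) cloud, as ShapeNetDataset does.

    Returns the (npoints, 3) cloud and a contiguous (1, 3, npoints) batch.
    """
    rng = np.random.default_rng() if rng is None else rng
    choice = rng.choice(len(points), npoints, replace=True)  # resample with replacement
    pts = points[choice, :].astype(np.float32)
    pts = pts - pts.mean(axis=0, keepdims=True)              # centroid to the origin
    pts = pts / np.max(np.sqrt(np.sum(pts ** 2, axis=1)))    # farthest point at distance 1
    batch = pts.T[np.newaxis, :, :].copy()                   # channels first, like .contiguous()
    return pts, batch


rng = np.random.default_rng(0)
raw = rng.normal(loc=5.0, scale=3.0, size=(4000, 3))  # stand-in for np.loadtxt('./point_test.pts')
pts, batch = prepare_cloud(raw, npoints=2500, rng=rng)
print(batch.shape)  # (1, 3, 2500)
```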
  • show3d_balls.py
'''
Point cloud visualization
'''
import numpy as np
import ctypes as ct
import cv2
import sys

showsz = 800
mousex, mousey = 0.5, 0.5
zoom = 1.0
changed = True


def onmouse(*args):
    global mousex, mousey, changed
    y = args[1]
    x = args[2]
    mousex = x / float(showsz)
    mousey = y / float(showsz)
    changed = True


cv2.namedWindow('show3d')
cv2.moveWindow('show3d', 0, 0)
cv2.setMouseCallback('show3d', onmouse)

dll = np.ctypeslib.load_library('render_balls_so', '.')  # render_balls_so.so is loaded from the current directory


# Input: xyz is an n*3 matrix; c_gt and c_pred are optional n*3 color matrices
def showpoints(xyz, c_gt=None, c_pred=None, waittime=0, showrot=False,
               magnifyBlue=0, freezerot=False, background=(0, 0, 0),
               normalizecolor=True, ballradius=10):
    global showsz, mousex, mousey, zoom, changed
    # Center the cloud and scale it to fit the window
    xyz = xyz - xyz.mean(axis=0)
    radius = ((xyz ** 2).sum(axis=-1) ** 0.5).max()
    xyz /= (radius * 2.2) / showsz
    # White points when no colors are given
    if c_gt is None:
        c0 = np.zeros((len(xyz), ), dtype='float32') + 255
        c1 = np.zeros((len(xyz), ), dtype='float32') + 255
        c2 = np.zeros((len(xyz), ), dtype='float32') + 255
    else:
        c0 = c_gt[:, 0]
        c1 = c_gt[:, 1]
        c2 = c_gt[:, 2]

    if normalizecolor:
        c0 /= (c0.max() + 1e-14) / 255.0
        c1 /= (c1.max() + 1e-14) / 255.0
        c2 /= (c2.max() + 1e-14) / 255.0

    c0 = np.require(c0, 'float32', 'C')
    c1 = np.require(c1, 'float32', 'C')
    c2 = np.require(c2, 'float32', 'C')

    show = np.zeros((showsz, showsz, 3), dtype='uint8')

    def render():
        # Rotation about x, driven by the mouse
        rotmat = np.eye(3)
        if not freezerot:
            xangle = (mousey - 0.5) * np.pi * 1.2
        else:
            xangle = 0
        rotmat = rotmat.dot(
            np.array([
                [1.0, 0.0, 0.0],
                [0.0, np.cos(xangle), -np.sin(xangle)],
                [0.0, np.sin(xangle), np.cos(xangle)],
            ]))
        # Rotation about y, driven by the mouse
        if not freezerot:
            yangle = (mousex - 0.5) * np.pi * 1.2
        else:
            yangle = 0
        rotmat = rotmat.dot(
            np.array([
                [np.cos(yangle), 0.0, -np.sin(yangle)],
                [0.0, 1.0, 0.0],
                [np.sin(yangle), 0.0, np.cos(yangle)],
            ]))
        rotmat *= zoom
        nxyz = xyz.dot(rotmat) + [showsz / 2, showsz / 2, 0]

        ixyz = nxyz.astype('int32')
        show[:] = background
        dll.render_ball(
            ct.c_int(show.shape[0]), ct.c_int(show.shape[1]),
            show.ctypes.data_as(ct.c_void_p), ct.c_int(ixyz.shape[0]),
            ixyz.ctypes.data_as(ct.c_void_p), c0.ctypes.data_as(ct.c_void_p),
            c1.ctypes.data_as(ct.c_void_p), c2.ctypes.data_as(ct.c_void_p),
            ct.c_int(ballradius))

        if magnifyBlue > 0:
            show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], 1, axis=0))
            if magnifyBlue >= 2:
                show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], -1, axis=0))
            show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], 1, axis=1))
            if magnifyBlue >= 2:
                show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], -1, axis=1))
        if showrot:
            # Red text; (0, 0, 255) is red in OpenCV's BGR order (cv2.cv.CV_RGB no longer exists in OpenCV 3+)
            cv2.putText(show, 'xangle %d' % (int(xangle / np.pi * 180)),
                        (30, showsz - 30), 0, 0.5, (0, 0, 255))
            cv2.putText(show, 'yangle %d' % (int(yangle / np.pi * 180)),
                        (30, showsz - 50), 0, 0.5, (0, 0, 255))
            cv2.putText(show, 'zoom %d%%' % (int(zoom * 100)),
                        (30, showsz - 70), 0, 0.5, (0, 0, 255))

    changed = True
    while True:
        if changed:
            render()
            changed = False
        cv2.imshow('show3d', show)
        if waittime == 0:
            cmd = cv2.waitKey(10) % 256
        else:
            cmd = cv2.waitKey(waittime) % 256
        if cmd == ord('q'):
            break
        elif cmd == ord('Q'):
            sys.exit(0)

        if cmd == ord('t') or cmd == ord('p'):
            if cmd == ord('t'):  # 't': ground-truth colors
                if c_gt is None:
                    c0 = np.zeros((len(xyz), ), dtype='float32') + 255
                    c1 = np.zeros((len(xyz), ), dtype='float32') + 255
                    c2 = np.zeros((len(xyz), ), dtype='float32') + 255
                else:
                    c0 = c_gt[:, 0]
                    c1 = c_gt[:, 1]
                    c2 = c_gt[:, 2]
            else:  # 'p': predicted colors
                if c_pred is None:
                    c0 = np.zeros((len(xyz), ), dtype='float32') + 255
                    c1 = np.zeros((len(xyz), ), dtype='float32') + 255
                    c2 = np.zeros((len(xyz), ), dtype='float32') + 255
                else:
                    c0 = c_pred[:, 0]
                    c1 = c_pred[:, 1]
                    c2 = c_pred[:, 2]
            if normalizecolor:
                c0 /= (c0.max() + 1e-14) / 255.0
                c1 /= (c1.max() + 1e-14) / 255.0
                c2 /= (c2.max() + 1e-14) / 255.0
            c0 = np.require(c0, 'float32', 'C')
            c1 = np.require(c1, 'float32', 'C')
            c2 = np.require(c2, 'float32', 'C')
            changed = True

        if cmd == ord('n'):
            zoom *= 1.1
            changed = True
        elif cmd == ord('m'):
            zoom /= 1.1
            changed = True
        elif cmd == ord('r'):
            zoom = 1.0
            changed = True
        elif cmd == ord('s'):
            cv2.imwrite('show3d.png', show)
        if waittime != 0:
            break
    return cmd


if __name__ == '__main__':
    np.random.seed(100)
    showpoints(np.random.randn(2500, 3))
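The geometry inside showpoints() and render() does not depend on OpenCV or on the C++ renderer. The cloud is centered and scaled so that its farthest point lands showsz / 2.2 pixels from the middle. It is then rotated about x and y according to the mouse position, zoomed, and shifted to the window center. The sketch below isolates that step so it can be checked without opening a window. project_points is a hypothetical name. It returns the same integer coordinate triples that showpoints hands to render_ball.

```python
import numpy as np


def project_points(xyz, mousex=0.5, mousey=0.5, showsz=800, zoom=1.0):
    """Integer window coordinates for each point, following showpoints()/render()."""
    xyz = xyz - xyz.mean(axis=0)
    radius = np.sqrt((xyz ** 2).sum(axis=-1)).max()
    xyz = xyz / ((radius * 2.2) / showsz)          # farthest point at showsz / 2.2 pixels

    xangle = (mousey - 0.5) * np.pi * 1.2          # mouse position -> rotation angles
    yangle = (mousex - 0.5) * np.pi * 1.2
    rot_x = np.array([[1.0, 0.0, 0.0],
                      [0.0, np.cos(xangle), -np.sin(xangle)],
                      [0.0, np.sin(xangle), np.cos(xangle)]])
    rot_y = np.array([[np.cos(yangle), 0.0, -np.sin(yangle)],
                      [0.0, 1.0, 0.0],
                      [np.sin(yangle), 0.0, np.cos(yangle)]])
    rotmat = np.eye(3).dot(rot_x).dot(rot_y) * zoom
    nxyz = xyz.dot(rotmat) + [showsz / 2, showsz / 2, 0]  # shift to the window center
    return nxyz.astype('int32')


pix = project_points(np.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]))
print(pix.tolist())  # [[763, 400, 0], [36, 400, 0]]
```

In the real viewer the keyboard does the rest. t and p switch between the c_gt and c_pred colors, n and m zoom in and out, and r resets the zoom. s saves show3d.png, q leaves the loop, and Q exits the program.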
  • train_classification.py
  • train_segmentation.py
'''
Train the part segmentation network
'''
# from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data

import sys
sys.path.append('/home/s/pointnet.pytorch')

from pointnet.dataset import ShapeNetDataset
from pointnet.model import PointNetDenseCls, feature_transform_regularizer
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=32, help='input batch size')  # 32 clouds per batch; one pass over the data takes len(dataset)/batchSize batches
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')  # pass over the whole training set 25 times
parser.add_argument('--outf', type=str, default='seg', help='output folder')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
parser.add_argument('--class_choice', type=str, default='Chair', help="class_choice")
parser.add_argument('--feature_transform', action='store_true', help="use feature transform")

opt = parser.parse_args()
# print(opt)

# Seed the random number generators. The seed itself is drawn at random, so a run can
# only be reproduced if the seed is recorded (the print below is commented out).
# NumPy's generator, used for resampling in the dataset, is not seeded here.
opt.manualSeed = random.randint(1, 10000)  # fix seed
# print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)  # seeds Python's random module (returns None)
torch.manual_seed(opt.manualSeed)

dataset = ShapeNetDataset(
    root=opt.dataset,
    classification=False,
    class_choice=[opt.class_choice])
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers))

test_dataset = ShapeNetDataset(
    root=opt.dataset,
    classification=False,
    class_choice=[opt.class_choice],
    split='test',
    data_augmentation=False)
testdataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers))

# print(len(dataset), len(test_dataset))  # 1958  341
num_classes = dataset.num_seg_classes  # number of parts of the chosen category, looked up by class_choice in misc/num_seg_classes.txt
# print('-----classes', num_classes)  # 4
try:
    os.makedirs(opt.outf)
except OSError:
    pass

blue = lambda x: '\033[94m' + x + '\033[0m'  # print text in blue

classifier = PointNetDenseCls(k=num_classes, feature_transform=opt.feature_transform)  # segmentation network

if opt.model != '':  # resume from a saved model if one is given
    classifier.load_state_dict(torch.load(opt.model))

optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # lr is multiplied by gamma once every step_size calls to scheduler.step()
classifier.cuda()

num_batch = len(dataset) / opt.batchSize  # number of batches needed to go through the dataset

for epoch in range(opt.nepoch):
    # scheduler.step()  # (moved to the end of the epoch)
    # print('----------lr:{}'.format(optimizer.state_dict()['param_groups'][0]['lr']))
    for i, data in enumerate(dataloader, 0):  # i is the batch index (up to num_batch); each batch has batchSize clouds of 2500 points
        # print('data:{}'.format(data))
        points, target = data
        # print('points.shape:{}'.format(points.size()))
        # print('target.shape:{}'.format(target.size()))
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()  # reset gradients
        classifier = classifier.train()  # training mode
        pred, trans, trans_feat = classifier(points)
        pred = pred.view(-1, num_classes)  # num_classes is the number of parts, e.g. torch.Size([20000, 4])
        target = target.view(-1, 1)[:, 0] - 1  # flatten; labels go from 1-based to 0-based
        # print(pred.size(), target.size())  # torch.Size([20000, 4]) torch.Size([20000])
        loss = F.nll_loss(pred, target)  # loss
        if opt.feature_transform:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        loss.backward()  # backpropagate
        optimizer.step()  # one optimizer step per batch
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        # target.size(0) is the number of points in this batch; the last batch may hold fewer than batchSize clouds
        print('[%d/%d: %d/%d] train loss: %f accuracy: %f' % (epoch, opt.nepoch, i, num_batch, loss.item(), correct.item() / float(target.size(0))))

        if i % 10 == 0:  # every 10 batches (10 * batchSize clouds), evaluate on one test batch
            j, data = next(enumerate(testdataloader, 0))
            points, target = data
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            classifier = classifier.eval()
            pred, _, _ = classifier(points)
            pred = pred.view(-1, num_classes)
            target = target.view(-1, 1)[:, 0] - 1
            loss = F.nll_loss(pred, target)
            # print('pred.shape:{}'.format(pred.size()))
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            print('[%d/%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, opt.nepoch, i, num_batch, blue('test'), loss.item(), correct.item() / float(target.size(0))))

    scheduler.step()  # update the learning rate; it is halved every step_size=5 epochs
    torch.save(classifier.state_dict(), '%s/seg_model_%s_%d.pth' % (opt.outf, opt.class_choice, epoch))  # save a checkpoint after every epoch

# Evaluate on the whole test set
## benchmark mIOU
shape_ious = []
for i, data in tqdm(enumerate(testdataloader, 0)):  # tqdm progress bar
    points, target = data
    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    classifier = classifier.eval()  # evaluation mode
    pred, _, _ = classifier(points)
    print('\npred.shape:{}'.format(pred.shape))  # pred.shape:torch.Size([8, 2500, 4])
    pred_choice = pred.data.max(2)[1]

    pred_np = pred_choice.cpu().data.numpy()
    target_np = target.cpu().data.numpy() - 1
    print('target_np.shape:{}\n'.format(target_np.shape))  # target_np.shape:(8, 2500)

    for shape_idx in range(target_np.shape[0]):
        parts = range(num_classes)  # np.unique(target_np[shape_idx])
        part_ious = []
        for part in parts:
            I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
            U = np.sum(np.logical_or(pred_np[shape_idx] == part, target_np[shape_idx] == part))
            if U == 0:
                iou = 1  # If the union of groundtruth and prediction points is empty, then count part IoU as 1
            else:
                iou = I / float(U)
            part_ious.append(iou)
        shape_ious.append(np.mean(part_ious))

print("mIOU for class {}: {}".format(opt.class_choice, np.mean(shape_ious)))
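Stripped of the data loader and GPU calls, the benchmark at the end of train_segmentation.py reduces to the function below. It is a sketch assuming 0-based labels, because the script subtracts 1 from the targets first. A part absent from both prediction and ground truth counts as IoU 1, so the score depends on the number of parts passed in. shape_miou and the toy arrays are illustrative.

```python
import numpy as np


def shape_miou(pred, target, num_parts):
    """Mean over shapes of the mean per-part IoU.

    pred, target: (B, N) integer part labels, 0-based.
    """
    shape_ious = []
    for p, t in zip(pred, target):
        part_ious = []
        for part in range(num_parts):
            inter = np.sum((p == part) & (t == part))
            union = np.sum((p == part) | (t == part))
            part_ious.append(1.0 if union == 0 else inter / float(union))  # empty union -> 1
        shape_ious.append(np.mean(part_ious))
    return float(np.mean(shape_ious))


pred = np.array([[0, 0, 1, 1], [2, 2, 2, 2]])
target = np.array([[0, 1, 1, 1], [2, 2, 2, 2]])
# Shape 1: parts 0, 1, 2 give 1/2, 2/3, 1; shape 2: all 1. Overall (13/18 + 1) / 2
print(shape_miou(pred, target, num_parts=3))  # about 0.861
```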

pointnet

├── dataset.py
├── __init__.py
├── model.py
└── __pycache__
    ├── dataset.cpython-36.pyc
    ├── __init__.cpython-36.pyc
    └── model.cpython-36.pyc

1 directory, 6 files
Overview
  • dataset.py: dataset loading
  • model.py: network definitions
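The core indexing step of ShapeNetDataset.__init__ turns the split file into a flat list of (category name, .pts path, .seg path) tuples. Each entry in shuffled_{split}_file_list.json looks like shape_data/<synset id>/<uuid>. The condensed, standard-library-only sketch below shows that step. build_datapath is a hypothetical name, and the uuids are made up.

```python
import os


def build_datapath(root, filelist, cat):
    """(name, .pts path, .seg path) tuples for the categories in cat (name -> synset id)."""
    id2cat = {v: k for k, v in cat.items()}
    meta = {name: [] for name in cat}
    for entry in filelist:
        _, category, uuid = entry.split('/')   # 'shape_data/<synset id>/<uuid>'
        if category in id2cat:                 # skip categories not chosen
            meta[id2cat[category]].append(
                (os.path.join(root, category, 'points', uuid + '.pts'),
                 os.path.join(root, category, 'points_label', uuid + '.seg')))
    return [(name, pts, seg) for name in cat for pts, seg in meta[name]]


cat = {'Airplane': '02691156'}  # what self.cat looks like after filtering by class_choice
filelist = ['shape_data/02691156/aaa', 'shape_data/04379243/bbb']
datapath = build_datapath('root', filelist, cat)
classes = dict(zip(sorted(cat), range(len(cat))))  # {'Airplane': 0}, as in __init__
print(datapath)
```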
Detailed code annotations
  • dataset.py
# from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import numpy as np
import sys
from tqdm import tqdm
import json
from plyfile import PlyData, PlyElement


def get_segmentation_classes(root):
    catfile = os.path.join(root, 'synsetoffset2category.txt')
    cat = {}
    meta = {}

    with open(catfile, 'r') as f:
        for line in f:
            ls = line.strip().split()
            cat[ls[0]] = ls[1]

    for item in cat:
        dir_seg = os.path.join(root, cat[item], 'points_label')
        dir_point = os.path.join(root, cat[item], 'points')
        fns = sorted(os.listdir(dir_point))
        meta[item] = []
        for fn in fns:
            token = (os.path.splitext(os.path.basename(fn))[0])
            meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))

    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'w') as f:
        for item in cat:
            datapath = []
            num_seg_classes = 0
            for fn in meta[item]:
                datapath.append((item, fn[0], fn[1]))

            for i in tqdm(range(len(datapath))):
                l = len(np.unique(np.loadtxt(datapath[i][-1]).astype(np.uint8)))
                if l > num_seg_classes:
                    num_seg_classes = l

            print("category {} num segmentation classes {}".format(item, num_seg_classes))
            f.write("{}\t{}\n".format(item, num_seg_classes))


def gen_modelnet_id(root):
    classes = []
    with open(os.path.join(root, 'train.txt'), 'r') as f:
        for line in f:
            classes.append(line.strip().split('/')[0])
    classes = np.unique(classes)
    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'w') as f:
        for i in range(len(classes)):
            f.write('{}\t{}\n'.format(classes[i], i))


# Dataset class for ShapeNet
class ShapeNetDataset(data.Dataset):
    def __init__(self,
                 root,
                 npoints=2500,
                 classification=False,
                 class_choice=None,
                 # split='test',
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        # print('self.catfile:{}'.format(self.catfile))  # path of catfile
        self.cat = {}  # dict: category name -> synset id
        self.data_augmentation = data_augmentation  # defaults to True
        self.classification = classification  # defaults to False
        self.seg_classes = {}  # dict: category name -> number of parts

        # Read synsetoffset2category.txt
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()  # ls is a list, e.g. ['Airplane', '02691156']
                # print('ls:', ls)
                self.cat[ls[0]] = ls[1]
        # print('self.cat:', self.cat)  # self.cat: {'Airplane': '02691156', 'Bag': '02773838', 'Cap': '02954340', 'Car': '02958343', 'Chair': '03001627', 'Earphone': '03261776', 'Guitar': '03467517', 'Knife': '03624134', 'Lamp': '03636649', 'Laptop': '03642806', 'Motorbike': '03790512', 'Mug': '03797390', 'Pistol': '03948459', 'Rocket': '04099429', 'Skateboard': '04225987', 'Table': '04379243'}
        # print('self.cat.items() :', self.cat.items())  # items() yields (key, value) tuples: [(), (), ...]
        if not class_choice is None:  # only runs when class_choice is given
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}  # e.g. {'Airplane': '02691156'}
            # print('self.cat:{}'.format(self.cat))

        self.id2cat = {v: k for k, v in self.cat.items()}
        # print('self.id2cat:{}'.format(self.id2cat))  # {'02691156': 'Airplane'}

        self.meta = {}  # dict
        splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split))  # e.g. shuffled_train_file_list.json
        # print('splitfile:{}'.format(splitfile))
        # from IPython import embed; embed()
        filelist = json.load(open(splitfile, 'r'))  # read the split file
        for item in self.cat:  # self.cat: {'Airplane': '02691156'}
            # print('item:{}'.format(item))  # item: Airplane
            self.meta[item] = []
            # print('self.meta:{}'.format(self.meta))  # self.meta: {'Airplane': []}
            # print('self.cat.values():{}'.format(self.cat.values()))

        for file in filelist:
            # print('file:{}'.format(file))
            _, category, uuid = file.split('/')  # each entry looks like shape_data/04379243/9e3f1901ea14aca753315facdf531a34
            if category in self.cat.values():
                # print("category:{}".format(category))  # 04379243
                self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid + '.pts'),
                                                         os.path.join(self.root, category, 'points_label', uuid + '.seg')))
        # print("self.meta:{}".format(self.meta))  # {'Airplane': [(), (), ..., ()]}

        self.datapath = []  # list
        for item in self.cat:  # self.cat: {'Airplane': '02691156'}
            # print('item:{}'.format(item))  # item: Airplane
            for fn in self.meta[item]:
                self.datapath.append((item, fn[0], fn[1]))  # datapath: [(), ...], each () is ('Airplane', .pts path, .seg path)
        # print("datapath:{}".format(self.datapath))
        # print('cat:{}'.format(self.cat))  # cat: {'Airplane': '02691156'}
        # print('len(self.cat):{}'.format(len(self.cat)))  # len(self.cat): 1

        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))  # zip the sorted names with 0..n-1 and turn the pairs into a dict
        # print(self.classes)  # {'Airplane': 0}
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()  # ls is a list
                self.seg_classes[ls[0]] = int(ls[1])  # seg_classes is a dict -> {'Airplane': 4, 'Bag': 2, 'Cap': 2, 'Car': 4, 'Chair': 4, 'Earphone': 3, 'Guitar': 3, 'Knife': 2, 'Lamp': 4, 'Laptop': 2, 'Motorbike': 6, 'Mug': 2, 'Pistol': 3, 'Rocket': 3, 'Skateboard': 3, 'Table': 3}
        # print('seg_classes:{}'.format(self.seg_classes))
        # print('self.cat.keys():{}'.format(self.cat.keys()))  # dict_keys(['Airplane'])
        # print('list(self.cat.keys())[0]:{}'.format(list(self.cat.keys())[0]))  # converted to a list -> ['Airplane']
        self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]  # number of parts of the first chosen category (only meaningful when a single category is chosen)
        print(self.seg_classes, list(self.cat.keys())[0], self.num_seg_classes)  # prints the dict, the category name and the number of parts, e.g. {...} Airplane 4

    def __getitem__(self, index):
        fn = self.datapath[index]  # fn is a tuple ('Airplane', .pts path, .seg path)
        cls = self.classes[self.datapath[index][0]]  # cls is the class index from self.classes
        # print('cls:{}'.format(cls))
        # Read the point cloud and its per-point part labels
        point_set = np.loadtxt(fn[1]).astype(np.float32)
        seg = np.loadtxt(fn[2]).astype(np.int64)
        # print('point_set.shape:{}_____ seg.shape:{}'.format(point_set.shape, seg.shape))  # point_set.shape:(2658, 3)_____ seg.shape:(2658,)

        # Resample to self.npoints points
        choice = np.random.choice(len(seg), self.npoints, replace=True)
        # resample
        point_set = point_set[choice, :]
        # print('np.mean(point_set, axis = 0):{}'.format(np.mean(point_set, axis=0)))  # centroid, 1*3
        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)  # largest distance to the origin
        # print('dist:{}'.format(dist))
        point_set = point_set / dist  # scale into the unit sphere

        if self.data_augmentation:  # defaults to True: random rotation about the y axis plus Gaussian jitter
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        seg = seg[choice]
        point_set = torch.from_numpy(point_set)
        seg = torch.from_numpy(seg)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))  # class index, e.g. Airplane -> 0

        if self.classification:  # classification defaults to False
            return point_set, cls
        else:
            return point_set, seg

    def __len__(self):
        # print('len(self.datapath):{}'.format(len(self.datapath)))
        return len(self.datapath)


class ModelNetDataset(data.Dataset):
    def __init__(self,
                 root,
                 npoints=2500,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.split = split
        self.data_augmentation = data_augmentation
        self.fns = []
        with open(os.path.join(root, '{}.txt'.format(self.split)), 'r') as f:
            for line in f:
                self.fns.append(line.strip())

        self.cat = {}
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = int(ls[1])

        print(self.cat)
        self.classes = list(self.cat.keys())

    def __getitem__(self, index):
        fn = self.fns[index]
        cls = self.cat[fn.split('/')[0]]
        with open(os.path.join(self.root, fn), 'rb') as f:
            plydata = PlyData.read(f)
        pts = np.vstack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']]).T
        choice = np.random.choice(len(pts), self.npoints, replace=True)
        point_set = pts[choice, :]

        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist  # scale

        if self.data_augmentation:
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        point_set = torch.from_numpy(point_set.astype(np.float32))
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))
        return point_set, cls

    def __len__(self):
        return len(self.fns)


if __name__ == '__main__':
    dataset = sys.argv[1]
    datapath = sys.argv[2]

    if dataset == 'shapenet':
        d = ShapeNetDataset(root=datapath, class_choice=['Bag'])
        print('len(d):{}'.format(len(d)))
        ps, seg = d[0]
        print(ps.size(), ps.type(), seg.size(), seg.type())
        print('--------------------------------------------------------')
        d = ShapeNetDataset(root=datapath, classification=True)
        print('len(d):{}'.format(len(d)))
        ps, cls = d[0]
        print(ps.size(), ps.type(), cls.size(), cls.type())
        # get_segmentation_classes(datapath)

    if dataset == 'modelnet':
        gen_modelnet_id(datapath)
        d = ModelNetDataset(root=datapath)
        print(len(d))
        print(d[0])
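The augmentation branch of __getitem__ rotates only the x and z columns. In other words, it spins each cloud around the vertical y axis by a random angle, then adds Gaussian noise with standard deviation 0.02 to every coordinate. The sketch below reproduces that step with a seeded NumPy generator. augment is a hypothetical helper. Setting sigma=0 isolates the rotation, so it can be checked that point norms and y coordinates are unchanged.

```python
import numpy as np


def augment(point_set, rng, sigma=0.02):
    """Random rotation about the y axis plus Gaussian jitter, as in ShapeNetDataset."""
    out = point_set.astype(np.float64).copy()
    theta = rng.uniform(0, np.pi * 2)
    rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],
                                [np.sin(theta), np.cos(theta)]])
    out[:, [0, 2]] = out[:, [0, 2]].dot(rotation_matrix)  # rotate in the x-z plane only
    out += rng.normal(0, sigma, size=out.shape)           # per-coordinate jitter
    return out


rng = np.random.default_rng(1)
pts = rng.normal(size=(100, 3))
rotated = augment(pts, rng, sigma=0.0)  # no jitter: a pure rotation
print(np.allclose(np.linalg.norm(rotated, axis=1), np.linalg.norm(pts, axis=1)))  # True
```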
  • model.py

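  • Contains the following components
    • Models
      • STN3d: predicts a 3x3 transform applied to the input points
      • STNkd: predicts a kxk transform applied to point features (k=64)
      • PointNetfeat: shared feature extractor returning a global feature or per-point features
      • PointNetCls: classification network
      • PointNetDenseCls: segmentation network
    • Functions
      • feature_transform_regularizer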
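All of these networks rely on two building blocks. A Conv1d with kernel size 1 applies one shared linear map to every point independently. torch.max over the point dimension (dim 2) is symmetric, so the global feature does not depend on point order. The NumPy sketch below, with a hypothetical shared_mlp_maxpool helper and random weights, checks that permuting the points leaves the pooled feature unchanged.

```python
import numpy as np


def shared_mlp_maxpool(points, weight, bias):
    """Per-point linear layer + ReLU, then max over points.

    points: (C_in, N), the channels-first layout used by Conv1d in model.py;
    weight: (C_out, C_in); bias: (C_out,).
    """
    features = np.maximum(weight.dot(points) + bias[:, None], 0.0)  # (C_out, N), same map for every point
    return features.max(axis=1)                                     # (C_out,) global feature


rng = np.random.default_rng(0)
points = rng.normal(size=(3, 2500))  # one cloud, channels first
weight = rng.normal(size=(64, 3))
bias = rng.normal(size=64)
perm = rng.permutation(2500)
g1 = shared_mlp_maxpool(points, weight, bias)
g2 = shared_mlp_maxpool(points[:, perm], weight, bias)
print(np.allclose(g1, g2))  # True
```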
# from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as Fclass STN3d(nn.Module):def __init__(self):super(STN3d, self).__init__()self.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 9)self.relu = nn.ReLU()self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)def forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)if x.is_cuda:iden = iden.cuda()x = x + idenx = x.view(-1, 3, 3)return xclass STNkd(nn.Module):def __init__(self, k=64):super(STNkd, self).__init__()self.conv1 = torch.nn.Conv1d(k, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k*k)self.relu = nn.ReLU()self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)self.k = kdef forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)if x.is_cuda:iden = iden.cuda()x = x + idenx = x.view(-1, self.k, self.k)return xclass PointNetfeat(nn.Module):def __init__(self, global_feat = True, feature_transform = 
False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)                    # 3x3 input transform predicted by the T-Net
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)          # 64x64 feature transform
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x                          # per-point features, 64 channels
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]   # symmetric max pooling over all points
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            # copy the global feature to every point: 1024 + 64 = 1088 channels
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat


class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat


class PointNetDenseCls(nn.Module):
    def __init__(self, k = 2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feature_transform=feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)   # 1088 = 1024 (global) + 64 (per-point)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        # print('x:{}'.format(x.shape))  #x:torch.Size([32, 3, 2500])
        x = x.transpose(1,2).contiguous()   # transpose() swaps dims 1 and 2; contiguous() makes the memory layout contiguous so view() can be used
        # print('x.transpose(2,1):{}'.format(x.shape))  #x.transpose(2,1):torch.Size([32, 2500, 3])
        x = F.log_softmax(x.view(-1,self.k), dim=-1)
        # print('log_softmax:{}'.format(x.shape))  #torch.Size([20000, 4])
        x = x.view(batchsize, n_pts, self.k)
        # print('x.shape:{}'.format(x.shape))
        return x, trans, trans_feat


def feature_transform_regularizer(trans):
    d = trans.size()[1]
    batchsize = trans.size()[0]
    I = torch.eye(d)[None, :, :]   # eye: creates a d x d identity matrix; [None] adds a batch dim for broadcasting
    if trans.is_cuda:
        I = I.cuda()
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
    return loss


if __name__ == '__main__':
    sim_data = Variable(torch.rand(32,3,2500))
    trans = STN3d()
    out = trans(sim_data)
    print('stn', out.size())
    print('loss', feature_transform_regularizer(out))

    sim_data_64d = Variable(torch.rand(32, 64, 2500))
    trans = STNkd(k=64)
    out = trans(sim_data_64d)
    print('stn64d', out.size())
    print('loss', feature_transform_regularizer(out))

    pointfeat = PointNetfeat(global_feat=True)
    out, _, _ = pointfeat(sim_data)
    print('global feat', out.size())

    pointfeat = PointNetfeat(global_feat=False)
    out, _, _ = pointfeat(sim_data)
    print('point feat', out.size())

    cls = PointNetCls(k = 5)
    out, _, _ = cls(sim_data)
    print('class', out.size())

    seg = PointNetDenseCls(k = 3)
    out, _, _ = seg(sim_data)
    print('seg', out.size())
    print(seg)
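feature_transform_regularizer penalizes how far each predicted transform A is from being orthogonal. Reading the code (torch.norm with dim=(1,2) is the Frobenius norm, then torch.mean averages over the batch), the loss is mean over b of ||A_b A_b^T - I||_F. Note that the PointNet paper writes the squared Frobenius norm, while this implementation uses the plain norm. The sketch below recomputes the same quantity in dependency-free Python on small, hand-checkable matrices; the helper names (matmul, transpose, frobenius, feature_transform_regularizer_py) are hypothetical and not part of pointnet.pytorch.

```python
import math

def matmul(a, b):
    """Multiply two square matrices stored as lists of rows."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]

def transpose(a):
    return [list(row) for row in zip(*a)]

def frobenius(a):
    """Frobenius norm: square root of the sum of squared entries."""
    return math.sqrt(sum(v * v for row in a for v in row))

def feature_transform_regularizer_py(batch):
    """Hypothetical pure-Python mirror of the torch code: mean of ||A A^T - I||_F."""
    losses = []
    for a in batch:
        d = len(a)
        aat = matmul(a, transpose(a))
        diff = [[aat[i][j] - (1.0 if i == j else 0.0) for j in range(d)] for i in range(d)]
        losses.append(frobenius(diff))
    return sum(losses) / len(losses)

theta = math.pi / 6
rotation = [[math.cos(theta), -math.sin(theta), 0.0],
            [math.sin(theta),  math.cos(theta), 0.0],
            [0.0, 0.0, 1.0]]
scaling = [[2.0, 0.0, 0.0],
           [0.0, 1.0, 0.0],
           [0.0, 0.0, 1.0]]

print(feature_transform_regularizer_py([rotation]))  # close to 0: a rotation is orthogonal
print(feature_transform_regularizer_py([scaling]))   # 3.0, since A A^T = diag(4, 1, 1)
```

A pure rotation gives a loss that is zero up to rounding, while stretching one axis by 2 gives 3. The term therefore pulls the 64x64 feature transform toward an orthogonal matrix, which the PointNet paper motivates as a transform that does not lose information in its input.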
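PointNetfeat relies on two ideas that are easy to miss in the tensor code: a Conv1d with kernel size 1 is a per-point MLP whose weights are shared by all points, and torch.max over dimension 2 is a symmetric function, so the global feature is the same whatever order the points arrive in. With global_feat=False the global vector is repeated n_pts times and concatenated in front of the 64-channel pointfeat, which is where the 1088 input channels of PointNetDenseCls.conv1 come from (1024 + 64). The pure-Python sketch below mimics this data flow with tiny sizes (3 -> 2 for the point feature, 2 -> 4 for the global feature) and random weights. It leaves out BatchNorm, ReLU and both T-Nets, and the names shared_mlp, global_max and encode are hypothetical.

```python
import random

def shared_mlp(points, weights):
    """Apply the same linear map to every point, as Conv1d(c_in, c_out, 1) does."""
    return [[sum(w * p for w, p in zip(row, pt)) for row in weights] for pt in points]

def global_max(features):
    """Channel-wise max over all points, like torch.max(x, 2)[0]."""
    return [max(channel) for channel in zip(*features)]

random.seed(0)
n_pts = 5
points = [[random.uniform(-1, 1) for _ in range(3)] for _ in range(n_pts)]
w1 = [[random.gauss(0, 1) for _ in range(3)] for _ in range(2)]  # stands in for conv1 (3 -> 64)
w2 = [[random.gauss(0, 1) for _ in range(2)] for _ in range(4)]  # stands in for conv2/conv3 (64 -> 1024)

def encode(pts):
    point_feat = shared_mlp(pts, w1)                      # one 2-dim feature per point
    global_feat = global_max(shared_mlp(point_feat, w2))  # a single 4-dim vector
    return point_feat, global_feat

point_feat, global_feat = encode(points)

# Permutation invariance: shuffling the points leaves the global feature unchanged.
shuffled = points[:]
random.shuffle(shuffled)
assert encode(shuffled)[1] == global_feat

# Segmentation branch (global_feat=False): put the global vector in front of each
# point's own feature, like torch.cat([x, pointfeat], 1) after repeat().
dense = [global_feat + pf for pf in point_feat]
print(len(dense), len(dense[0]))  # 5 6  (real model: 1088 channels per point)
```

With the real model and the 32 x 3 x 2500 random input from the __main__ block, the same layout appears as a point feat output of torch.Size([32, 1088, 2500]), and PointNetDenseCls(k = 3) returns torch.Size([32, 2500, 3]) after the transpose and log_softmax.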

Reference

  • Concise Python series: the difference between if not x:, if x is not None: and if not x is None:
  • Python dictionary items() method
  • Python: building a dict from two lists
  • Python __getitem__ method
  • Python's four magic methods: __len__, __getitem__, __setitem__, __delitem__

Open-source framework PointNet code walkthrough: /pointnet/sem_seg/train.py
[3D Computer Vision] From PointNet to PointNet++: theory and PyTorch code

Tools

  • Refresh the GPU status every 5 seconds
watch -n 5 nvidia-smi 

http://www.chinasem.cn/article/755953
