Deep Learning Model Pruning: Pcdet-PointPillars Pruning Workflow and Results

2023-12-30 16:48

This article walks through the pruning workflow for Pcdet-PointPillars and the resulting measurements, as a practical reference for developers who want to prune this model.

1. Original Pcdet-PointPillars model structure

     The network consists of four parts:
    (1)PillarVFE
    (2)PointPillarScatter
    (3)BaseBEVBackbone
    (4)AnchorHeadSingle

Pruning is applied mainly to the BaseBEVBackbone module. Its network structure diagram is shown below:
[Figure: BaseBEVBackbone network structure]
The printed module structure is as follows:

  (backbone_2d): BaseBEVBackbone(
    (blocks): ModuleList(
      (0): Sequential(
        (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (2): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (3): ReLU()
        (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (8): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (9): ReLU()
        (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (11): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (12): ReLU()
      )
      (1): Sequential(
        (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (2): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (3): ReLU()
        (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (8): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (9): ReLU()
        (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (11): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (12): ReLU()
        (13): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (14): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (15): ReLU()
        (16): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (17): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (18): ReLU()
      )
      (2): Sequential(
        (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (2): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (3): ReLU()
        (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (8): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (9): ReLU()
        (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (11): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (12): ReLU()
        (13): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (14): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (15): ReLU()
        (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (17): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (18): ReLU()
      )
    )
    (deblocks): ModuleList(
      (0): Sequential(
        (0): ConvTranspose2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (1): Sequential(
        (0): ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (2): Sequential(
        (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(4, 4), bias=False)
        (1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
  )

2. Pruning

2.1 Sparse training
    Apply an L1 penalty to the BN scale parameters during training so that most of them are driven toward zero; this reduces the accuracy loss caused by the subsequent pruning.

loss.backward()
updateBN(model)      # add the L1 sub-gradient on the BN scale factors before the optimizer step
optimizer.step()

def updateBN(model):
    s = 0.0001
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))  # L1

2.2 Pruning the sparsely trained model - Network Slimming

(1) Compute the threshold from the pruning ratio (percent)

total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index + size)] = m.weight.data.abs().clone()
        index += size

y, i = torch.sort(bn)
thre_index = int(total * 0.7)   # pruning ratio (percent) is 0.7 here
thre = y[thre_index]

(2) Generate cfg_index (the remaining channel count for each layer) and cfg_mask (the per-layer channel masks)

pruned = 0
cfg_index = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.cpu().gt(thre).float().cuda()
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg_index.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(
            k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg_index.append('M')

(3) For BN layers that should not be pruned, set every entry of that layer's cfg_mask to 1

For example:

cfg_mask[0][:] = 1   # do not prune the first BN layer

Note: 1) There is probably a cleaner way to handle this; it depends on the specific network, but for now this simply works (one way to keep cfg_index consistent is sketched below).
      2) Many layers cannot be pruned at all; take care which ones you touch.
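
A minimal sketch of one possible approach: force the protected layers' masks to all ones and rebuild cfg_index from cfg_mask so the two stay consistent. The no_prune_ids list below is only a hypothetical placeholder for whichever BN layers must keep all of their channels.

# Sketch: protect selected BN layers from pruning and keep cfg_index consistent.
# Assumes cfg_index holds only channel counts (no 'M' entries from MaxPool2d),
# which is the case for the PointPillars backbone.
no_prune_ids = [0]                  # hypothetical: BN layers that must keep all channels
for i in no_prune_ids:
    cfg_mask[i][:] = 1.0            # keep every channel of this BN layer
cfg_index = [int(mask.sum()) for mask in cfg_mask]   # remaining channels per BN layer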

(4) Build the pruned model skeleton from cfg_index

newmodel = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=test_set, cfg_index=cfg_index)
newmodel = newmodel.to(device='cuda:0')

Note that build_network has to be rewritten for this. Since I only prune BaseBEVBackbone, the input/output channel count of every convolution in that module is tied to the corresponding entry of cfg_index, as follows:

class BaseBEVBackbone(nn.Module):
    def __init__(self, model_cfg, input_channels, cfg_index=None):
        super().__init__()
        self.model_cfg = model_cfg
        if self.model_cfg.get('LAYER_NUMS', None) is not None:
            assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
            layer_nums = self.model_cfg.LAYER_NUMS
            layer_strides = self.model_cfg.LAYER_STRIDES
            num_filters = self.model_cfg.NUM_FILTERS
        else:
            layer_nums = layer_strides = num_filters = []

        if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:
            assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
            num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
            upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
        else:
            upsample_strides = num_upsample_filters = []

        num_levels = len(layer_nums)
        c_in_list = [input_channels, *num_filters[:-1]]
        self.blocks = nn.ModuleList()
        self.deblocks = nn.ModuleList()
        if cfg_index is None:
            cfg_index = [64, 64, 64, 64, 128, 128, 128, 128, 128, 128,
                         256, 256, 256, 256, 256, 256, 128, 128, 128]
        cfg = cfg_index
        for idx in range(num_levels):
            if idx == 0:
                cur_layers = [
                    nn.ZeroPad2d(1),
                    nn.Conv2d(64, 64, kernel_size=3, stride=layer_strides[idx], padding=0, bias=False),
                    nn.BatchNorm2d(64, eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ]
                for k in range(3):
                    if k == 0:
                        cur_layers.extend([
                            nn.Conv2d(64, cfg[k + 1], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k + 1], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    if k == 1:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 0], cfg[k + 1], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k + 1], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    if k == 2:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 0], 64, kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(64, eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
            elif idx == 1:
                cur_layers = [
                    nn.ZeroPad2d(1),
                    nn.Conv2d(64, cfg[4], kernel_size=3, stride=layer_strides[idx], padding=0, bias=False),
                    nn.BatchNorm2d(cfg[4], eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ]
                for k in range(5):
                    if k == 4:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 4], 128, kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(128, eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    else:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 4], cfg[k + 5], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k + 5], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
            elif idx == 2:
                cur_layers = [
                    nn.ZeroPad2d(1),
                    nn.Conv2d(128, cfg[10], kernel_size=3, stride=layer_strides[idx], padding=0, bias=False),
                    nn.BatchNorm2d(cfg[10], eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ]
                for k in range(5):
                    if k == 4:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 10], 256, kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(256, eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    else:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 10], cfg[k + 11], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k + 11], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
            self.blocks.append(nn.Sequential(*cur_layers))
            if len(upsample_strides) > 0:
                stride = upsample_strides[idx]
                if stride >= 1:
                    self.deblocks.append(nn.Sequential(
                        nn.ConvTranspose2d(num_filters[idx], num_upsample_filters[idx],
                                           upsample_strides[idx], stride=upsample_strides[idx], bias=False),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))
                else:
                    stride = np.round(1 / stride).astype(np.int)
                    self.deblocks.append(nn.Sequential(
                        nn.Conv2d(num_filters[idx], num_upsample_filters[idx],
                                  stride, stride=stride, bias=False),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))

        c_in = sum(num_upsample_filters)
        if len(upsample_strides) > num_levels:
            self.deblocks.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False),
                nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ))

        self.num_bev_features = c_in

    def forward(self, data_dict):
        """
        Args:
            data_dict:
                spatial_features
        Returns:
        """
        spatial_features = data_dict['spatial_features']
        ups = []
        ret_dict = {}
        x = spatial_features
        for i in range(len(self.blocks)):
            x = self.blocks[i](x)
            stride = int(spatial_features.shape[2] / x.shape[2])
            ret_dict['spatial_features_%dx' % stride] = x
            if len(self.deblocks) > 0:
                ups.append(self.deblocks[i](x))
            else:
                ups.append(x)

        if len(ups) > 1:
            x = torch.cat(ups, dim=1)
        elif len(ups) == 1:
            x = ups[0]

        if len(self.deblocks) > len(self.blocks):
            x = self.deblocks[-1](x)

        data_dict['spatial_features_2d'] = x
        return data_dict

(5) Prune the parameters of the Conv and BN layers

old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(64)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0
bn_count = 0
for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        if bn_count == 0:
            # The first BatchNorm2d layer is kept intact (not pruned).
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()
            bn_count += 1
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            # Copy only the channels selected by this layer's mask.
            bn_count += 1
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        if conv_count == 0:
            # The first convolution is copied as-is.
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if layer_id == (len(old_modules) - 1):
            # The last layer is copied as-is.
            m1.weight.data = m0.weight.data.clone()
            continue
        if isinstance(old_modules[layer_id + 1], nn.BatchNorm2d):
            # Prune this convolution: keep the input channels selected by the previous
            # BN mask and the output channels selected by the following BN mask.
            conv_count += 1
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()

torch.save({'cfg': cfg, 'model_state': newmodel.state_dict()}, os.path.join('./', 'pruned_90.pth'))
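
A quick sanity check after the copy (a minimal sketch, not required by the flow above) is to compare parameter counts of the original and the pruned model:

# Sketch: report how many parameters remain after channel pruning.
num_old = sum(p.numel() for p in model.parameters())
num_new = sum(p.numel() for p in newmodel.parameters())
print('params before pruning: {}, after pruning: {} ({:.1f}% remaining)'.format(
    num_old, num_new, 100.0 * num_new / num_old))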

(6) Use the new model: load the pruned weights and evaluate (a loading sketch follows).
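
A minimal loading sketch, assuming the modified build_network from step (4) and the cfg_index list from step (2) are still in scope; the checkpoint keys match the torch.save call in step (5):

# Sketch: rebuild the pruned architecture and load the pruned weights for evaluation.
checkpoint = torch.load('pruned_90.pth', map_location='cuda:0')
prunedmodel = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES),
                            dataset=test_set, cfg_index=cfg_index)
prunedmodel.load_state_dict(checkpoint['model_state'])
prunedmodel.cuda()
prunedmodel.eval()   # evaluation then follows the usual OpenPCDet test pipeline (section 3)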

3. Testing steps

First, preprocess the dataset:

python -m pcdet.datasets.kitti.kitti_dataset create_kitti_infos tools/cfgs/dataset_configs/kitti_dataset.yaml

Then:
(1) Environment: array g03, zxw_compression container, /data/OpenPCDet-master/tools
(2) Commands:
    Testing:

CUDA_VISIBLE_DEVICES=2 python test.py --cfg_file cfgs/kitti_models/pointpillar.yaml --batch_size 1 --ckpt checkpoint_epoch_90.pth

    Training:

python train.py --cfg_file cfgs/kitti_models/pointpillar.yaml --batch_size 8 --epochs 100

(3) Each time you switch to a different OpenPCDet-master copy, you need to run:

python setup.py develop

4. Pruning results

[Figures: pruning results]

5. Summary

(1) Not every model can be accelerated by pruning: many networks contain layers without BN, or have too few prunable layers, so the speedup is marginal.
(2) From the results, the PointPillars network portion in PCDET takes very little time; most of the time is spent in the NMS module of the post-processing stage, and the reason for its cost has not yet been investigated in depth.
(3) Pruning the backbone_2d layers makes the post-processing stage slower; this still needs to be investigated.
(4) For networks with many Conv2d+BN blocks and a large number of layers, such as ResNet-152, pruning is a viable way to accelerate inference. Note: pruning in this article means channel pruning based on the BN scale parameters (Network Slimming); other pruning methods are not covered.
