深度学习模型剪枝: Pcdet-PointPillars 剪枝流程及结果

2023-12-30 16:48

本文主要是介绍深度学习模型剪枝: Pcdet-PointPillars 剪枝流程及结果,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.Pcdet-PointPillars原始模型结构

     网络部分包含4部分:
    (1)PillarVFE
    (2)PointPillarScatter
    (3)BaseBEVBackbone
    (4)AnchorHeadSingle

主要对BaseBEVBackbone部分剪枝,BaseBEVBackbone网络结构图如下:
在这里插入图片描述
具体如下:

  (backbone_2d): BaseBEVBackbone((blocks): ModuleList((0): Sequential((0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)(1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)(2): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(3): ReLU()(4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(5): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(6): ReLU()(7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(8): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(9): ReLU()(10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(11): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(12): ReLU())(1): Sequential((0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)(1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)(2): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(3): ReLU()(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(5): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(6): ReLU()(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(8): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(9): ReLU()(10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(11): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(12): ReLU()(13): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(14): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(15): ReLU()(16): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(17): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(18): ReLU())(2): Sequential((0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)(1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)(2): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(3): ReLU()(4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(5): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(6): ReLU()(7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(8): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(9): ReLU()(10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(11): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(12): ReLU()(13): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(14): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(15): ReLU()(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(17): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(18): ReLU()))(deblocks): ModuleList((0): Sequential((0): ConvTranspose2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(2): ReLU())(1): Sequential((0): ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(2): ReLU())(2): Sequential((0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(4, 4), bias=False)(1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)(2): ReLU())))

2.剪枝

2.1稀疏训练
    对BN层的参数进行诱导,让大部分参数趋于零,降低剪枝对模型精度的影响

loss.backward()
updateBN(model)
optimizer.step()
def updateBN(model):s = 0.0001for m in model.modules():if isinstance(m, torch.nn.BatchNorm2d):m.weight.grad.data.add_(s*torch.sign(m.weight.data))  # L1

2.2对稀疏训练后的模型剪枝-Network_Slimming

(1)根据剪枝率(percent)计算阈值

    total = 0for m in model.modules():if isinstance(m, nn.BatchNorm2d):total += m.weight.data.shape[0]bn = torch.zeros(total)index = 0for m in model.modules():if isinstance(m, nn.BatchNorm2d):size = m.weight.data.shape[0]bn[index:(index+size)] = m.weight.data.abs().clone()index += sizey, i = torch.sort(bn)thre_index = int(total * 0.7)thre = y[thre_index]

(2)生成cfg_index(通道剪枝个数索引列表)与cfg_mask

    pruned = 0cfg_index = []cfg_mask = []for k, m in enumerate(model.modules()):if isinstance(m, nn.BatchNorm2d):weight_copy = m.weight.data.abs().clone()mask = weight_copy.cpu().gt(thre).float().cuda()#pdb.set_trace()pruned = pruned + mask.shape[0] - torch.sum(mask)m.weight.data.mul_(mask)m.bias.data.mul_(mask)cfg_index.append(int(torch.sum(mask)))cfg_mask.append(mask.clone())print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))elif isinstance(m, nn.MaxPool2d):cfg_index.append('M')

(3)对不想剪枝的bn层,cfg_mask该bn层参数全部置1

例如:

cfg_mask[0][:]=1   ##对第一个bn不剪枝

注:1)应该有更好的办法,具体问题具体分析,现在只是实现了
       2)有很多层不能剪枝,请注意

(4)根据cfg_index构建剪枝后模型框架

newmodel = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=test_set, cfg_index=cfg_index)
newmodel = newmodel.to(device='cuda:0')

注意,此处的build_network需要改写,我主要是剪枝BaseBEVBackbone,所以将此模块的每个卷积层的输入输出尺寸与cfg_index对应,如下:

class BaseBEVBackbone(nn.Module):def __init__(self, model_cfg, input_channels, cfg_index=None):super().__init__()self.model_cfg = model_cfgif self.model_cfg.get('LAYER_NUMS', None) is not None:assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)layer_nums = self.model_cfg.LAYER_NUMSlayer_strides = self.model_cfg.LAYER_STRIDESnum_filters = self.model_cfg.NUM_FILTERSelse:layer_nums = layer_strides = num_filters = []if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERSupsample_strides = self.model_cfg.UPSAMPLE_STRIDESelse:upsample_strides = num_upsample_filters = [] num_levels = len(layer_nums)c_in_list = [input_channels, *num_filters[:-1]]self.blocks = nn.ModuleList()self.deblocks = nn.ModuleList()if cfg_index is None:cfg_index = [64, 64, 64, 64, 128, 128, 128, 128, 128, 128, 256, 256, 256, 256, 256, 256, 128, 128, 128]cfg=cfg_indexfor idx in range(num_levels): if idx == 0:cur_layers = [nn.ZeroPad2d(1),nn.Conv2d(64,64, kernel_size=3,stride=layer_strides[idx], padding=0, bias=False),nn.BatchNorm2d(64, eps=1e-3, momentum=0.01),nn.ReLU()]for k in range(3):if k ==0:cur_layers.extend([nn.Conv2d(64, cfg[k+1], kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(cfg[k+1], eps=1e-3, momentum=0.01),nn.ReLU()])if k ==1:cur_layers.extend([nn.Conv2d(cfg[k+0], cfg[k+1], kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(cfg[k+1], eps=1e-3, momentum=0.01),nn.ReLU()])if k ==2:cur_layers.extend([nn.Conv2d(cfg[k+0], 64, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(64, eps=1e-3, momentum=0.01),nn.ReLU()])elif idx ==1 :cur_layers = [nn.ZeroPad2d(1),nn.Conv2d(64, cfg[4], kernel_size=3,stride=layer_strides[idx], padding=0, bias=False),nn.BatchNorm2d(cfg[4], eps=1e-3, momentum=0.01),nn.ReLU()]for k in range(5):if k ==4:cur_layers.extend([nn.Conv2d(cfg[k+4], 128, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(128, eps=1e-3, momentum=0.01),nn.ReLU()])else:cur_layers.extend([nn.Conv2d(cfg[k+4], cfg[k+5], kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(cfg[k+5], eps=1e-3, momentum=0.01),nn.ReLU()])elif idx ==2 :cur_layers = [nn.ZeroPad2d(1),nn.Conv2d(128, cfg[10], kernel_size=3,stride=layer_strides[idx], padding=0, bias=False),nn.BatchNorm2d(cfg[10], eps=1e-3, momentum=0.01),nn.ReLU()]for k in range(5):if k==4:cur_layers.extend([nn.Conv2d(cfg[k+10], 256, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(256, eps=1e-3, momentum=0.01),nn.ReLU()])else:cur_layers.extend([nn.Conv2d(cfg[k+10], cfg[k+11], kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(cfg[k+11], eps=1e-3, momentum=0.01),nn.ReLU()])self.blocks.append(nn.Sequential(*cur_layers))if len(upsample_strides) > 0:stride = upsample_strides[idx]if stride >= 1:self.deblocks.append(nn.Sequential(nn.ConvTranspose2d(num_filters[idx], num_upsample_filters[idx],upsample_strides[idx],stride=upsample_strides[idx], bias=False),nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),nn.ReLU()))else:stride = np.round(1 / stride).astype(np.int)self.deblocks.append(nn.Sequential(nn.Conv2d(num_filters[idx], num_upsample_filters[idx],stride,stride=stride, bias=False),nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),nn.ReLU()))c_in = sum(num_upsample_filters)if len(upsample_strides) > num_levels:self.deblocks.append(nn.Sequential(nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False),nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),nn.ReLU(),))self.num_bev_features = c_in########testdef forward(self, data_dict):"""Args:data_dict:spatial_featuresReturns:"""spatial_features = data_dict['spatial_features'] ups = []ret_dict = {}x = spatial_featuresfor i in range(len(self.blocks)):         x = self.blocks[i](x)            stride = int(spatial_features.shape[2] / x.shape[2])ret_dict['spatial_features_%dx' % stride] = xif len(self.deblocks) > 0:ups.append(self.deblocks[i](x))else:ups.append(x)if len(ups) > 1:x = torch.cat(ups, dim=1)elif len(ups) == 1:x = ups[0]if len(self.deblocks) > len(self.blocks):x = self.deblocks[-1](x)data_dict['spatial_features_2d'] = xreturn data_dict

(5)对conv层及bn层参数进行剪枝

    old_modules = list(model.modules())new_modules = list(newmodel.modules())layer_id_in_cfg = 0start_mask = torch.ones(64)end_mask = cfg_mask[layer_id_in_cfg]conv_count = 0bn_count = 0for layer_id in range(len(old_modules)):m0 = old_modules[layer_id]m1 = new_modules[layer_id]#print("old_modules  is: ", old_modules)if isinstance(m0, nn.BatchNorm2d):idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))if idx1.size == 1:idx1 = np.resize(idx1,(1,))if bn_count == 0 :# If the next layer is the channel selection layer, then the current batchnorm 2d layer won't be pruned.m1.weight.data = m0.weight.data.clone()m1.bias.data = m0.bias.data.clone()m1.running_mean = m0.running_mean.clone()m1.running_var = m0.running_var.clone()bn_count += 1layer_id_in_cfg += 1start_mask = end_mask.clone()if layer_id_in_cfg < len(cfg_mask):end_mask = cfg_mask[layer_id_in_cfg]else:bn_count += 1m1.weight.data = m0.weight.data[idx1.tolist()].clone()m1.bias.data = m0.bias.data[idx1.tolist()].clone()m1.running_mean = m0.running_mean[idx1.tolist()].clone()m1.running_var = m0.running_var[idx1.tolist()].clone()layer_id_in_cfg += 1start_mask = end_mask.clone()if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FCend_mask = cfg_mask[layer_id_in_cfg]elif isinstance(m0, nn.Conv2d):if conv_count == 0:m1.weight.data = m0.weight.data.clone()conv_count += 1continueif layer_id == (len(old_modules)-1):m1.weight.data = m0.weight.data.clone()continueif isinstance(old_modules[layer_id+1], nn.BatchNorm2d):# This convers the convolutions in the residual block.# The convolutions are either after the channel selection layer or after the batch normalization layer.conv_count += 1idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))if idx0.size == 1:idx0 = np.resize(idx0, (1,))if idx1.size == 1:idx1 = np.resize(idx1, (1,))w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()# If the current convolution is not the last convolution in the residual block, then we can change the # number of output channels. Currently we use `conv_count` to detect whether it is such convolution.w1 = w1[idx1.tolist(), :, :, :].clone()m1.weight.data = w1.clone()torch.save({'cfg': cfg, 'model_state': newmodel.state_dict()}, os.path.join('./', 'pruned_90.pth'))

(6)使用新模型,并load参数,测试效果

3.测试步骤

首先数据预处理:

python -m pcdet.datasets.kitti.kitti_dataset create_kitti_infos tools/cfgs/dataset_configs/kitti_dataset.yaml

然后:
(1).环境位置:阵列g03,zxw_compression容器,/data/OpenPCDet-master/tools
(2).运行命令:
    测试:

CUDA_VISIBLE_DEVICES=2 python test.py --cfg_file cfgs/kitti_models/pointpillar.yaml --batch_size 1 --ckpt checkpoint_epoch_90.pth

    训练:

python train.py --cfg_file cfgs/kitti_models/pointpillar.yaml --batch_size 8 --epochs 100

(3)每次切换一个OpenPCDet-master,需要运行命令

python setup.py develop

4.剪枝结果

在这里插入图片描述
在这里插入图片描述

5.总结

(1)不是所有的模型都能用剪枝来加速,很多网络层没有BN,或者可剪枝的层数过少,增速不明显;
(2)从结果来看,PCDET中的pointpillars网络部分耗时很少,主要时间浪费在后处理中的NMS模块,还未深入研究此模块耗时原因;
(3)剪枝Backbone2d层会减少后处理速度,有待研究;
(4)对于有大量conv2d+bn组合的网络结构,网络层数较多的,例如resnet152,可以采用剪枝来加速。注:本文所说的剪枝,指的是根据bn参数,对通道剪枝,不涉及其他剪枝

这篇关于深度学习模型剪枝: Pcdet-PointPillars 剪枝流程及结果的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/553455

相关文章

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

Spring Security中用户名和密码的验证完整流程

《SpringSecurity中用户名和密码的验证完整流程》本文给大家介绍SpringSecurity中用户名和密码的验证完整流程,本文结合实例代码给大家介绍的非常详细,对大家的学习或工作具有一定... 首先创建了一个UsernamePasswordAuthenticationTChina编程oken对象,这是S

深度解析Python装饰器常见用法与进阶技巧

《深度解析Python装饰器常见用法与进阶技巧》Python装饰器(Decorator)是提升代码可读性与复用性的强大工具,本文将深入解析Python装饰器的原理,常见用法,进阶技巧与最佳实践,希望可... 目录装饰器的基本原理函数装饰器的常见用法带参数的装饰器类装饰器与方法装饰器装饰器的嵌套与组合进阶技巧

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现

深度解析Spring AOP @Aspect 原理、实战与最佳实践教程

《深度解析SpringAOP@Aspect原理、实战与最佳实践教程》文章系统讲解了SpringAOP核心概念、实现方式及原理,涵盖横切关注点分离、代理机制(JDK/CGLIB)、切入点类型、性能... 目录1. @ASPect 核心概念1.1 AOP 编程范式1.2 @Aspect 关键特性2. 完整代码实

SpringBoot开发中十大常见陷阱深度解析与避坑指南

《SpringBoot开发中十大常见陷阱深度解析与避坑指南》在SpringBoot的开发过程中,即使是经验丰富的开发者也难免会遇到各种棘手的问题,本文将针对SpringBoot开发中十大常见的“坑... 目录引言一、配置总出错?是不是同时用了.properties和.yml?二、换个位置配置就失效?搞清楚加

Android ViewBinding使用流程

《AndroidViewBinding使用流程》AndroidViewBinding是Jetpack组件,替代findViewById,提供类型安全、空安全和编译时检查,代码简洁且性能优化,相比Da... 目录一、核心概念二、ViewBinding优点三、使用流程1. 启用 ViewBinding (模块级

SpringBoot整合Flowable实现工作流的详细流程

《SpringBoot整合Flowable实现工作流的详细流程》Flowable是一个使用Java编写的轻量级业务流程引擎,Flowable流程引擎可用于部署BPMN2.0流程定义,创建这些流程定义的... 目录1、流程引擎介绍2、创建项目3、画流程图4、开发接口4.1 Java 类梳理4.2 查看流程图4

java Long 与long之间的转换流程

《javaLong与long之间的转换流程》Long类提供了一些方法,用于在long和其他数据类型(如String)之间进行转换,本文将详细介绍如何在Java中实现Long和long之间的转换,感... 目录概述流程步骤1:将long转换为Long对象步骤2:将Longhttp://www.cppcns.c