tensorflow 网络修剪 剪枝操作

2024-04-03 21:18

本文主要是介绍tensorflow 网络修剪 剪枝操作,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

背景知识

模型剪枝(Model Pruning)是一种模型压缩方法,对深度神经网络的稠密连接引入稀疏性,通过将“不重要”的权值直接置零来减少非零权值数量,其历史可追溯到上世纪 90 年代初。

在 Optimal Brain Damage【2】中,使用对角 Hessian 逼近计算每个权值的重要性,重要性低的权值被置零,然后重新训练网络。

在 Optimal Brain Surgeon【3】中,使用逆 Hessian 矩阵计算每个权值的重要性,重要性低的权值被置零,剩下的权值使用二阶泰勒逼近的 loss 增量更新。

最近比较流行基于幅度的权值剪枝方法【4】,该方法将权值取绝对值,与设定的 threshhold 值进行比较,低于门限的权值被置零。基于幅度的权值剪枝算法计算高效,可以应用到大部分模型和数据集。TensorFlow 也使用了基于幅度的权值剪枝算法。


TF 代码实现

TensorFlow 代码目录 tensorflow/contrib/model_pruning/ 提供了对 TensorFlow  框架的扩展,可在模型训练时实现剪枝。

对每个被选中做剪枝的层增加一个二进制掩模(mask)变量,形状和该层的权值张量形状完全相同。该掩模决定了哪些权值参与前向计算。掩模更新算法则需要为 TensorFlow 训练计算图注入特殊运算符,对当前层权值按绝对值大小排序,对幅度小于一定门限的权值将其对应掩模值设为 0。反向传播梯度也经过掩模,被屏蔽的权值(mask 为 0)在反向传播步骤中无法获得更新量。

研究发现稀疏度不宜从一开始就设置最大,这样容易将重要的权值剪掉造成无法挽回的准确率损失,更好的方法是渐进稀疏度,从初始稀疏度 (一般为 0 )开始,逐步增大到最终稀疏度 ,这期间二进制掩模变量 mask 经历了 n 次更新,每次更新时的门限由当时的稀疏度决定,稀疏度由如下公式计算得到:

随着训练过程,逐步提高稀疏度,直到达到期望的稀疏度 为止。

下图很直观地反映了渐进提高稀疏度的过程。

初始时刻,稀疏度提升较快,而越到后面,稀疏度提升速度会逐渐放缓,这个比较符合直觉,因为初始时有大量冗余的权值,而越到后面保留的权值数量越少,不能再“大刀阔斧”地修剪,而需要更谨慎些,避免“误伤无辜”。

下面 TensorFlow 代码创建了带有 mask 变量的 graph:

from tensorflow.contrib.model_pruning.python import pruningwith tf.variable_scope('conv1') as scope:# 创建权值 variablekernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)# 创建 conv2d op,权值 variable 增加 maskconv = tf.nn.conv2d(images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')

下面代码给出了带剪枝的模型训练代码结构:

from tensorflow.contrib.model_pruning.python import pruning# 命令行参数解析pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)# 创建剪枝对象pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)# 使用剪枝对象向训练图增加更新 mask 的运算符# 当且仅当训练步骤位于 [begin_pruning_step, end_pruning_step] 之间时,# conditional_mask_update_op 才会更新 maskmask_update_op = pruning_obj.conditional_mask_update_op()# 使用剪枝对象写入 summaries,用于跟踪每层权值 sparsity 变化pruning_obj.add_pruning_summaries()with tf.train.MonitoredTrainingSession() as mon_sess:while not mon_sess.should_stop():mon_sess.run(train_op)# 更新 maskmon_sess.run(mask_update_op)	

其中 FLAGS.pruning_hparams 为一组逗号分隔的键值对,取值如下表所示:

超参名类型默认值说明
begin_pruning_stepinteger0开始剪枝的全局 step
end_pruning_stepinteger-1结束剪枝的全局 step,默认为 -1 标识剪枝一直持续到训练结束
do_not_prunelist of strings[""]一组层名,标记哪些层不做剪枝
threshold_decayfloat0.9衰减因子,用于门限衰减
pruning_frequencyinteger10mask 更新的频率,计数单位为全局 step 数
initial_sparsityfloat0.0初始稀疏度值
target_sparsityfloat0.5目标稀疏度值
sparsity_function_begin_stepinteger0渐进稀疏度函数开始时刻
sparsity_function_end_stepinteger100渐进稀疏度函数结束时刻
sparsity_function_exponentfloat3.0

指数项,=1 则为线性增长,>1 则初始快后续慢

转者注: 详细代码可参照https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10    (例子)

           和 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/model_pruning/python/pruning.py  (pruning源码)                                                                                                                                                                                                                

实践

TensorFlow model pruning 自带 CIFAR10 例程,实现了一个稀疏 CNN 模型,其中卷积层和 local 层的权值均做了稀疏化。

(1) 准备 TensorFlow r1.7 环境

硬件环境:GTX 1080

软件环境:CUDA 9.0 + cuDNN 7, Bazel 0.11.1

git clone https://github.com/tensorflow/tensorflow.git
cd tensorflow/
git checkout r1.7

(2) 编译、运行 tensorflow/contrib/model_pruning/

cd tensorflow/contrib/model_pruning/
bazel build -c opt examples/cifar10:cifar10_{train,val}
cd ../../
bazel-bin/contrib/model_pruning/examples/cifar10/cifar10_train -prune_hparams=name=cifar10_pruning,begin_pruning_step=10000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000

(3) 查看训练过程

运行 TensorBoard:

tensorboard --logdir /tmp/cifar10_train/

打开浏览器,输入 localhost:6006

可以看到随着训练步骤增加,conv1 和 conv2 的 sparsity 在不断增长。总的 loss 变化如下图所示:

(4) 查看计算图

切换到 GRAPHS 页面,双击 conv2 节点,可以看到在原有计算图基础上新增了 mask 和 threshold 节点用来做 model pruning。

(5) 模型评估

利用以下命令对训练模型进行评估:

bazel-bin/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval

训练 15 万次迭代的结果(仅供参考)

SparsityAccuracy after 150K steps
0%86%
50%86%
75%
90%
95%77%

论文【1】中一些结论

随着稀疏度提高,模型质量逐渐下降,其表现为分类准确率降低。下表为 InceptionV3 模型不同稀疏度的情况【1】:

从表中看到,50% 系数模型和基准模型(0% 稀疏度)表现一致,而 87.5% 稀疏度模型的 top-5 准确率相比基准模型只有 2% 降低,但模型非零权值数量减少为原来 1/8。

我们前面文章《用于移动和嵌入式视觉应用的 MobileNets》介绍过轻量 CNN 模型 MobileNet,是一类特别为移动视觉应用设计的高效卷积神经网络。MobileNet 基于 depthwise separable 卷积,将通道内滤波通道间线性组合分解为两个独立步骤,显著减少了参数数量。MobileNet 网络架构包括一个标准卷积层用于处理输入图片,一大堆 depthwise separable conv,最后为 average pooling 和全连接层。 width multiplier 是 MobileNet 的一个调节参数,能实现模型准确率和模型权值数量、计算量的 trade-off。

我们既可以通过设置更小的 width multiplier 实现尺寸更小的模型(准确率会降低),也可以通过对原始 MobileNet 做稀疏化得到尺寸更小的模型(准确率同样会降低),那么这两种方法哪种更有效呢?论文【1】 给出了结果:

基本结论为:大而稀疏的模型(large-sparse)表现优于小而稠密的模型(small-dense)。

例如,75% 稀疏度模型( 1.09 M 权值,top-1 accuracy 为 67.7%)优于稠密 0.5 MobileNet( 1.32 M 权值,top-1 accuracy 为 63.7%)。

类似地, 90% 稀疏度模型(0.46 M 权值,top-1 accuracy 61.8%)优于稠密 0.25 MobileNet(0.46 M 权值,top-1 accuracy 50.6%)。

通过这些结论,为轻量级模型设计提供了新的思路,同时也为专用硬件加速器设计提供了参考。


参考文献

【1】 To prune, or not to prune: exploring the efficacy of pruning for model compression, arXiv:1710.01878

【2】Yann LeCun et.al. Optimal brain damage. NIPS, 1990

【3】B.Hassibi et.al. Optimal brain surgeeon and general network pruning. ICNN, 1993

【4】Song Han et.al. Learning both weights and connections for efficient neural network. NIPS, 2015


这篇关于tensorflow 网络修剪 剪枝操作的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/873998

相关文章

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

揭秘Python Socket网络编程的7种硬核用法

《揭秘PythonSocket网络编程的7种硬核用法》Socket不仅能做聊天室,还能干一大堆硬核操作,这篇文章就带大家看看Python网络编程的7种超实用玩法,感兴趣的小伙伴可以跟随小编一起... 目录1.端口扫描器:探测开放端口2.简易 HTTP 服务器:10 秒搭个网页3.局域网游戏:多人联机对战4.

Mysql表的简单操作(基本技能)

《Mysql表的简单操作(基本技能)》在数据库中,表的操作主要包括表的创建、查看、修改、删除等,了解如何操作这些表是数据库管理和开发的基本技能,本文给大家介绍Mysql表的简单操作,感兴趣的朋友一起看... 目录3.1 创建表 3.2 查看表结构3.3 修改表3.4 实践案例:修改表在数据库中,表的操作主要

C# WinForms存储过程操作数据库的实例讲解

《C#WinForms存储过程操作数据库的实例讲解》:本文主要介绍C#WinForms存储过程操作数据库的实例,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、存储过程基础二、C# 调用流程1. 数据库连接配置2. 执行存储过程(增删改)3. 查询数据三、事务处

Java使用Curator进行ZooKeeper操作的详细教程

《Java使用Curator进行ZooKeeper操作的详细教程》ApacheCurator是一个基于ZooKeeper的Java客户端库,它极大地简化了使用ZooKeeper的开发工作,在分布式系统... 目录1、简述2、核心功能2.1 CuratorFramework2.2 Recipes3、示例实践3

Java利用JSONPath操作JSON数据的技术指南

《Java利用JSONPath操作JSON数据的技术指南》JSONPath是一种强大的工具,用于查询和操作JSON数据,类似于SQL的语法,它为处理复杂的JSON数据结构提供了简单且高效... 目录1、简述2、什么是 jsONPath?3、Java 示例3.1 基本查询3.2 过滤查询3.3 递归搜索3.4

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Python使用DrissionPage中ChromiumPage进行自动化网页操作

《Python使用DrissionPage中ChromiumPage进行自动化网页操作》DrissionPage作为一款轻量级且功能强大的浏览器自动化库,为开发者提供了丰富的功能支持,本文将使用Dri... 目录前言一、ChromiumPage基础操作1.初始化Drission 和 ChromiumPage

利用Go语言开发文件操作工具轻松处理所有文件

《利用Go语言开发文件操作工具轻松处理所有文件》在后端开发中,文件操作是一个非常常见但又容易出错的场景,本文小编要向大家介绍一个强大的Go语言文件操作工具库,它能帮你轻松处理各种文件操作场景... 目录为什么需要这个工具?核心功能详解1. 文件/目录存javascript在性检查2. 批量创建目录3. 文件

Linux系统之主机网络配置方式

《Linux系统之主机网络配置方式》:本文主要介绍Linux系统之主机网络配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、查看主机的网络参数1、查看主机名2、查看IP地址3、查看网关4、查看DNS二、配置网卡1、修改网卡配置文件2、nmcli工具【通用