浅谈知识蒸馏(Knowledge Distillation)

2024-02-09 03:58

本文主要是介绍浅谈知识蒸馏(Knowledge Distillation),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

浅谈知识蒸馏(Knowledge Distillation)

前言:

在实验室做算法研究时,我们最看重的一般是模型精度,因为精度是我们模型有效性的最直接证明。而在公司做研发时,除了算法精度,我们还很关注模型的大小和内存占用。因为实验室模型一般运行在服务器上,很少有运算资源不足的情况,但是公司研发的算法功能最终都是要部署到实际的产品上的,像手机或者小型计算平台,其运算资源是很有限的。所以算法工程师在公司做预研时,算法建模一般都分两部分:先根据需求建模,并尽可能提高模型精度;然后进行模型压缩,在保证算法精度的情况下尽可能减少其参数量。

常用的模型压缩方法: 知识蒸馏、权重共享、模型剪枝、网络量化以及低秩分解。本文我们主要介绍知识蒸馏。
模型剪枝示意图:左边为训练好的大模型,通过剪枝删除掉一切权值趋于零的节点,达到缩减模型参数的效果上图为模型剪枝示意图:左边为训练好的大模型,通过剪枝删除掉一切权值趋于零的节点,达到缩减模型参数的效果

知识蒸馏(Knowledge Distillation,KD):

我一直觉得训练神经网络的过程很像求解线性方程组,用已知的数据及标签(对应求解方程中的xy点对)来拟合一批模型参数(对应方程组的系数矩阵)。一般来讲,在数据量有限的情况下,如果我们的模型过大,就很容易出现过拟合现象,此时我们需要缩减模型参数,或者添加正则项。

但在数据量足够的情况下,网络模型越复杂、参数量越大,训练出的模型性能会越好,而较小的网络却很难达到大网络那么好的效果。要让一个小网络达到和大模型相近的性能,我们就需要换一个思路,让大模型在训练过程中学习到的知识迁移到小模型中,而这个迁移的过程就叫做知识蒸馏(Knowledge Distillation,KD)

知识蒸馏的开山之作为大佬Hinton发表在NIPS2014文章《Distilling the Knowledge in a Neural Network》。其主要思想是:在给定输入的情况下训练迁移模型(Student Network),让其输出与原模型(Teacher Network)的输出一致,从而达到将原模型学习到的知识迁移至小网络的目标。
在这里插入图片描述
上图为知识蒸馏模型训练示意图:左侧为大参数量的原模型(Teacher Network),右侧为小网络(Student Network)

训练过程中,原模型(Teacher)输出 vi 与小模型(Student)输出 zi 之间的一致性约束是知识蒸馏的关键所在,即最小化下式:
在这里插入图片描述
对于输出一致性约束,常用的一般为各种距离度量、或者K-L散度等。在神经网络模型中,训练模型就是让模型的softmax输出与Ground Truth匹配;而知识蒸馏任务中,我们需要让Student网络与Teacher网络的的softmax输出尽可能匹配。

下式定义为普通的Softmax函数:
在这里插入图片描述
从上面softmax函数的定义式中我们不难看出,它先通过指数函数拉大输出节点之间的差异,然后通过归一化输出一个接近one-hot的向量(其中一个值很大,其他值接近于0)。对于普通的分类等任务,这样的操作没什么问题,但在知识蒸馏中,这种one-hot形式的输出对于知识的体现很有限,并不利于Student网络的学习(容易放大错误分类的概率,引入不必要的噪声)。所以我们通过引入一个温度参数T来将softmax输出的hard分布转化为soft。

加温度参数T后的softmax定义如下:
在这里插入图片描述
上述公式可以理解为:将网络的输出除以温度参数T后再做softmax,这样可以获得比较soft的输出向量:向量中每个值介于0~1之间,各个值之间的差异没有one-hot那么大。并且T的数值越大,分布越缓和。

训练过程中,模型总体的损失函数由两部分组成如下所示:
在这里插入图片描述
其中,Alpha和Beta为权重参数,Lsoft 为Distill Loss,保证Student网络输出与Teacher网络输出保持一致性,其定义如下:
在这里插入图片描述
其中,pj 和 qj 分别为Teacher网络和Student网络在温度T下的softmax输出向量的第j个值。

因为Teacher网络虽然已经经过了预训练,但其输出也会有一定的误差,为了降低将这些误差迁移到Student网络的可能性,在训练时还添加了Lhard :通过Ground-truth对Student网络的约束损失,定义如下:
在这里插入图片描述

其中,cj 为第j个类别的Ground-truth,qj 为Student网络softmax输出向量的第j个值。

这篇关于浅谈知识蒸馏(Knowledge Distillation)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/693050

相关文章

Java架构师知识体认识

源码分析 常用设计模式 Proxy代理模式Factory工厂模式Singleton单例模式Delegate委派模式Strategy策略模式Prototype原型模式Template模板模式 Spring5 beans 接口实例化代理Bean操作 Context Ioc容器设计原理及高级特性Aop设计原理Factorybean与Beanfactory Transaction 声明式事物

sqlite3 相关知识

WAL 模式 VS 回滚模式 特性WAL 模式回滚模式(Rollback Journal)定义使用写前日志来记录变更。使用回滚日志来记录事务的所有修改。特点更高的并发性和性能;支持多读者和单写者。支持安全的事务回滚,但并发性较低。性能写入性能更好,尤其是读多写少的场景。写操作会造成较大的性能开销,尤其是在事务开始时。写入流程数据首先写入 WAL 文件,然后才从 WAL 刷新到主数据库。数据在开始

浅谈主机加固,六种有效的主机加固方法

在数字化时代,数据的价值不言而喻,但随之而来的安全威胁也日益严峻。从勒索病毒到内部泄露,企业的数据安全面临着前所未有的挑战。为了应对这些挑战,一种全新的主机加固解决方案应运而生。 MCK主机加固解决方案,采用先进的安全容器中间件技术,构建起一套内核级的纵深立体防护体系。这一体系突破了传统安全防护的局限,即使在管理员权限被恶意利用的情况下,也能确保服务器的安全稳定运行。 普适主机加固措施:

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

浅谈PHP5中垃圾回收算法(Garbage Collection)的演化

前言 PHP是一门托管型语言,在PHP编程中程序员不需要手工处理内存资源的分配与释放(使用C编写PHP或Zend扩展除外),这就意味着PHP本身实现了垃圾回收机制(Garbage Collection)。现在如果去PHP官方网站(php.net)可以看到,目前PHP5的两个分支版本PHP5.2和PHP5.3是分别更新的,这是因为许多项目仍然使用5.2版本的PHP,而5.3版本对5.2并不是完

【Python知识宝库】上下文管理器与with语句:资源管理的优雅方式

🎬 鸽芷咕:个人主页  🔥 个人专栏: 《C++干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 文章目录 前言一、什么是上下文管理器?二、上下文管理器的实现三、使用内置上下文管理器四、使用`contextlib`模块五、总结 前言 在Python编程中,资源管理是一个重要的主题,尤其是在处理文件、网络连接和数据库

dr 航迹推算 知识介绍

DR(Dead Reckoning)航迹推算是一种在航海、航空、车辆导航等领域中广泛使用的技术,用于估算物体的位置。DR航迹推算主要通过已知的初始位置和运动参数(如速度、方向)来预测物体的当前位置。以下是 DR 航迹推算的详细知识介绍: 1. 基本概念 Dead Reckoning(DR): 定义:通过利用已知的当前位置、速度、方向和时间间隔,计算物体在下一时刻的位置。应用:用于导航和定位,

【H2O2|全栈】Markdown | Md 笔记到底如何使用?【前端 · HTML前置知识】

Markdown的一些杂谈 目录 Markdown的一些杂谈 前言 准备工作 认识.Md文件 为什么使用Md? 怎么使用Md? ​编辑 怎么看别人给我的Md文件? Md文件命令 切换模式 粗体、倾斜、下划线、删除线和荧光标记 分级标题 水平线 引用 无序和有序列表 ​编辑 任务清单 插入链接和图片 内嵌代码和代码块 表格 公式 其他 源代码 预

图神经网络(2)预备知识

1. 图的基本概念         对于接触过数据结构和算法的读者来说,图并不是一个陌生的概念。一个图由一些顶点也称为节点和连接这些顶点的边组成。给定一个图G=(V,E),  其 中V={V1,V2,…,Vn}  是一个具有 n 个顶点的集合。 1.1邻接矩阵         我们用邻接矩阵A∈Rn×n表示顶点之间的连接关系。 如果顶点 vi和vj之间有连接,就表示(vi,vj)  组成了

JAVA初级掌握的J2SE知识(二)和Java核心的API

/** 这篇文章送给所有学习java的同学,请大家检验一下自己,不要自满,你们正在学习java的路上,你们要加油,蜕变是个痛苦的过程,忍受过后,才会蜕变! */ Java的核心API是非常庞大的,这给开发者来说带来了很大的方便,经常人有评论,java让程序员变傻。 但是一些内容我认为是必须掌握的,否则不可以熟练运用java,也不会使用就很难办了。 1、java.lang包下的80%以上的类