变分互信息蒸馏(Variational mutual information KD)

2024-01-31 11:30

本文主要是介绍变分互信息蒸馏(Variational mutual information KD),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

原文标题是Variational Information Distillation for Knowledge Transfer,是CVPR2019的录用paper。

VID方法

在这里插入图片描述
思路比较简单,就是利用互信息(mutual information,MI)的角度,增加teacher网络与student网络中间层特征的MI,motivation是因为MI可以表示两个变量的依赖程度,MI越大,表明两者的输出越相关。
首先定义输入数据 x ∼ p ( x ) \bm{x}\sim p(\bm{x}) xp(x),给定一个样本 x \bm{x} x,得到关于teacher和student输出的 K K K个对集合 R = { ( t ( k ) , s ( k ) ) } k = 1 K \mathcal{R}=\{(\bm{t}^{(k)},\bm{s}^{(k)})\}_{k=1}^{K} R={(t(k),s(k))}k=1K, K K K表示选择的层数。变量对的MI被定义为 I ( t ; s ) = H ( t ) − H ( t ∣ s ) = − E t [ log ⁡ p ( t ) ] + E t , s [ log ⁡ p ( t ∣ s ) ] I(\bm{t};\bm{s})=H(\bm{t})-H(\bm{t}|\bm{s})\\ =-\mathbb{E}_{\bm{t}}[\log p(\bm{t})]+\mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})] I(t;s)=H(t)H(ts)=Et[logp(t)]+Et,s[logp(ts)]
之后可以设计如下的loss函数来增大teacher和student之间的输出特征的互信息:
L = L S − ∑ k = 1 K λ k I ( t ( k ) , s ( k ) ) \mathcal{L}=\mathcal{L_{S}}-\sum_{k=1}^{K}\lambda_{k}I(\bm{t}^{(k)},\bm{s}^{(k)}) L=LSk=1KλkI(t(k),s(k))
其中 L S \mathcal{L_{S}} LS表示task-specific的误差, λ k \lambda_{k} λk是超参数用于平衡误差。因为精确的计算MI是困难的,这里采用了变分下界(variational lower bound)的trick,采用variational的思想使用一个variational分布 q ( t ∣ s ) q(\bm{t}|\bm{s}) q(ts)去近似真实分布 p ( t ∣ s ) p(\bm{t}|\bm{s}) p(ts)
Note that variational的思想就是针对某个分布很难求解的时候,采用另外一个分布来近似这个分布的做法,并使用变分信息最大化 (论文:The IM algorithm: A variational approach to information maximization) 的方法求解变分下界(variational low bound),这方法也被用在InfoGAN中。
I ( t ; s ) = H ( t ) − H ( t ∣ s ) = H ( t ) + E t , s [ log ⁡ p ( t ∣ s ) ] = H ( t ) + E t , s [ log ⁡ q ( t ∣ s ) ] + E s [ D K L ( p ( t ∣ s ) ∣ ∣ q ( t ∣ s ) ) ] ≥ H ( t ) + E t , s [ log ⁡ q ( t ∣ s ) ] I(\bm{t};\bm{s})=H(\bm{t})-H(\bm{t}|\bm{s})\\ =H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})]\\ =H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]+\mathbb{E}_{\bm{s}}[D_{KL}(p(\bm{t|s})||q(\bm{t|s}))]\\ \geq H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})] I(t;s)=H(t)H(ts)=H(t)+Et,s[logp(ts)]=H(t)+Et,s[logq(ts)]+Es[DKL(p(ts)q(ts))]H(t)+Et,s[logq(ts)]
E t , s [ log ⁡ p ( t ∣ s ) ] = E t , s [ log ⁡ q ( t ∣ s ) ] + E s [ D K L ( p ( t ∣ s ) ∣ ∣ q ( t ∣ s ) ) ] \mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})]=\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]+\mathbb{E}_{\bm{s}}[D_{KL}(p(\bm{t|s})||q(\bm{t|s}))] Et,s[logp(ts)]=Et,s[logq(ts)]+Es[DKL(p(ts)q(ts))]这个关系是由变分信息最大化中得到的,真实分布 log ⁡ p ( t ∣ s ) \log p(\bm{t|s}) logp(ts)的期望等于变分分布 E t , s [ log ⁡ q ( t ∣ s ) ] \mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})] Et,s[logq(ts)]的期望+两分布的KL散度期望。因为KL散度的值是恒大于0的,所以得到变分下界。进一步可以得到如下的误差函数:
L ~ = L S − ∑ k = 1 K λ k E t ( k ) , s ( k ) [ log ⁡ q ( t ( k ) ∣ s ( k ) ) ] \mathcal{\tilde{L}}=\mathcal{L_{S}}-\sum_{k=1}^{K}\lambda_{k}\mathbb{E}_{\bm{t^{(k)},s^{(k)}}}[\log q(\bm{t^{(k)}|s^{(k)}})] L~=LSk=1KλkEt(k),s(k)[logq(t(k)s(k))]
H ( t ) H(\bm{t}) H(t)由于和待优化的student参数无关,所以是常数。联合的训练学生网络利用target task和最大化条件似然去拟合teacher激活值。

作者采用高斯分布来实例化变分分布,这里的采用heteroscedastic的均值 μ ( ⋅ ) \bm{\mu}(\cdot) μ(),即 μ ( ⋅ ) \bm{\mu}(\cdot) μ()是关于student输出的函数;同时采用homoscedastic的方差 σ \bm{\sigma} σ,即不是关于student输出的函数,作者尝试采用heteroscedastic的均值 σ ( ⋅ ) \bm{\sigma}(\cdot) σ(),但是容易训练不稳定且提升不大。 μ ( ⋅ ) \bm{\mu}(\cdot) μ()其实就是相当于在feature KD时teacher与student之间的回归器,包含卷积等操作。
− log ⁡ q ( t ∣ s ) = − ∑ c = 1 C ∑ h = 1 H ∑ w = 1 W log ⁡ q ( t c , h , w ∣ s ) = ∑ c = 1 C ∑ h = 1 H ∑ w = 1 W log ⁡ σ c + ( t c , h , w − μ c , h , w ( s ) ) 2 2 σ c 2 + c o n s t a n t -\log q(\bm{t|s})=-\sum_{c=1}^{C}\sum_{h=1}^{H}\sum_{w=1}^{W}\log q(t_{c,h,w}|\bm{s})\\ =\sum_{c=1}^{C}\sum_{h=1}^{H}\sum_{w=1}^{W}\log \sigma_{c}+\frac{(t_{c,h,w}-\mu_{c,h,w}(\bm{s}))^{2}}{2\sigma_{c}^{2}}+\rm{constant} logq(ts)=c=1Ch=1Hw=1Wlogq(tc,h,ws)=c=1Ch=1Hw=1Wlogσc+2σc2(tc,h,wμc,h,w(s))2+constant
σ c = log ⁡ ( 1 + e x p ( α c ) ) \sigma_{c}=\log(1+exp(\alpha_{c})) σc=log(1+exp(αc)) α c \alpha_{c} αc是一个可学习的参数。
对于logit层, − log ⁡ q ( t ∣ s ) = − ∑ n = 1 N log ⁡ q ( t n ∣ s ) = ∑ n = 1 N log ⁡ σ n + ( t n − μ n ( s ) ) 2 2 σ n 2 + c o n s t a n t -\log q(\bm{t|s})=-\sum_{n=1}^{N}\log q(t_{n}|\bm{s})\\ =\sum_{n=1}^{N}\log \sigma_{n}+\frac{(t_{n}-\mu_{n}(\bm{s}))^{2}}{2\sigma_{n}^{2}}+\rm{constant} logq(ts)=n=1Nlogq(tns)=n=1Nlogσn+2σn2(tnμn(s))2+constant
这里 μ ( ⋅ ) \bm{\mu}(\cdot) μ()是一个线性的变换矩阵。

与MSE的区别

作者认为当前基于MSE的方法是该方法在方差相同时的特例,即为:
− log ⁡ q ( t ∣ s ) = ∑ n = 1 N ( t n − μ n ( s ) ) 2 2 + c o n s t a n t -\log q(\bm{t|s})=\sum_{n=1}^{N}\frac{(t_{n}-\mu_{n}(\bm{s}))^{2}}{2}+\rm{constant} logq(ts)=n=1N2(tnμn(s))2+constant
VID比MSE的好处为建模了不同维度的方差,使得更加灵活的方式来避免一些model capacity用来到一些无用的信息。MSE采用一样的方差会高度限制student,如果teacher的无用信息也同样的地位拟合,会造成过拟合问题,浪费掉了student的网络capacity。

这篇关于变分互信息蒸馏(Variational mutual information KD)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/663605

相关文章

20240621日志:大模型压缩-从闭源大模型蒸馏

目录 1. 核心内容2. 方法2.1 先验估计2.2 后验估计2.3 目标函数 3. 交叉熵损失函数与Kullback-Leibler(KL)损失函数 location:beijing 涉及知识:大模型压缩、知识蒸馏 Fig. 1 大模型压缩-知识蒸馏 1. 核心内容 本文提出在一个贝叶斯估计框架内估计闭源语言模型的输出分布,包括先验估计和后验估计。先验估计的目的是通

【C++PCL】点云处理Kd-tree原理

作者:迅卓科技 简介:本人从事过多项点云项目,并且负责的项目均已得到好评! 公众号:迅卓科技,一个可以让您可以学习点云的好地方 重点:每个模块都有参数如何调试的讲解,即调试某个参数对结果的影响是什么,大家有问题可以评论哈,如果文章有错误的地方,欢迎来指出错误的地方。 目录         1.原理介绍 1.原理介绍         kd-tree是散乱点云的一种储存结构,它是一种

Autoencoder(AE)、Variational Autoencoder(VAE)和Diffusion Models(DM)了解

Autoencoder (AE) 工作原理: Autoencoder就像一个数据压缩机器。它由两部分组成: 编码器:将输入数据压缩成一个小小的代码。解码器:将这个小代码还原成尽可能接近原始输入的数据。 优点和应用: 简单易懂:用于学习数据的特征和去除噪声。应用场景:例如可以用来缩小图像的大小但保留关键特征,或者去除文本数据中的错误。 挑战: 数据损坏:如果输入数据太乱,编码器可能无法有

论文学习 Learning Robust Representations via Multi-View Information Bottleneck

Code available at https://github.com/mfederici/Multi-View-Information-Bottleneck 摘要:信息瓶颈原理为表示学习提供了一种信息论方法,通过训练编码器保留与预测标签相关的所有信息,同时最小化表示中其他多余信息的数量。然而,最初的公式需要标记数据来识别多余的信息。在这项工作中,我们将这种能力扩展到多视图无监督设置,其中提供

KD-TREE 算法原理

KD-TREE 算法原理 http://www.oneie.com/index.php/qyjs/47-txcl/1532-kd-tree   本文介绍一种用于高维空间中的快速最近邻和近似最近邻查找技术——Kd- Tree(Kd树)。Kd-Tree,即K-dimensional tree,是一种高维索引树形数据结构,常用于在大规模的高维数据空间进行最近邻查找(Nearest Neighbor

KD Tree

转载地址:http://www.cnblogs.com/slysky/archive/2011/11/08/2241247.html KD Tree Kd-树 其实是K-dimension tree的缩写,是对数据点在k维空间中划分的一种数据结构。其实,Kd-树是一种平衡二叉树。 举一示例: 假设有六个二维数据点 = {(2,3),(5,4),(9,6),(4,7),(8,1)

四叉树和KD树

1. 简介 四叉树和KD树都是用于空间数据索引和检索的树状数据结构。它们通过将空间递归地划分为更小的区域,并存储每个区域内的点,来实现快速搜索和范围查询。 2. 四叉树 2.1 定义 四叉树是一种树状数据结构,它将二维空间递归地划分为四个相等的子区域,直到每个子区域只包含一个点或为空。每个节点代表一个矩形区域,并存储该区域内的所有点。 2.2 构建 构建四叉树的过程如下: 将整个空间

CVPR2024知识蒸馏Distillation论文49篇速通

Paper1 3D Paintbrush: Local Stylization of 3D Shapes with Cascaded Score Distillation 摘要小结: 我们介绍了3DPaintbrush技术,这是一种通过文本描述自动对网格上的局部语义区域进行纹理贴图的方法。我们的方法直接在网格上操作,生成的纹理图能够无缝集成到标准的图形管线中。我们选择同时生成一个定位图(指定编辑

量化、剪枝、蒸馏,这些大模型黑话到底说了些啥?

扎克伯格说,Llama3-8B还是太大了,不适合放到手机中,有什么办法? 量化、剪枝、蒸馏,如果你经常关注大语言模型,一定会看到这几个词,单看这几个字,我们很难理解它们都干了些什么,但是这几个词对于现阶段的大语言模型发展特别重要。这篇文章就带大家来认识认识它们,理解其中的原理。 模型压缩 量化、剪枝、蒸馏,其实是通用的神经网络模型压缩技术,不是大语言模型专有的。 模型压缩的意义 通过压缩

php: /usr/local/lib/libxml2.so.2: no version information available (required by php)

Linux下执行php *.php报php: /usr/local/lib/libxml2.so.2: no version information available (required by php)这个错误 解决办法: 把/usr/local/lib/libxml2.so.2这个文件删除就可以了,这是linux版本混乱的原因~~