变分互信息蒸馏(Variational mutual information KD)

2024-01-31 11:30

本文主要是介绍变分互信息蒸馏(Variational mutual information KD),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

原文标题是Variational Information Distillation for Knowledge Transfer,是CVPR2019的录用paper。

VID方法

在这里插入图片描述
思路比较简单,就是利用互信息(mutual information,MI)的角度,增加teacher网络与student网络中间层特征的MI,motivation是因为MI可以表示两个变量的依赖程度,MI越大,表明两者的输出越相关。
首先定义输入数据 x ∼ p ( x ) \bm{x}\sim p(\bm{x}) xp(x),给定一个样本 x \bm{x} x,得到关于teacher和student输出的 K K K个对集合 R = { ( t ( k ) , s ( k ) ) } k = 1 K \mathcal{R}=\{(\bm{t}^{(k)},\bm{s}^{(k)})\}_{k=1}^{K} R={(t(k),s(k))}k=1K, K K K表示选择的层数。变量对的MI被定义为 I ( t ; s ) = H ( t ) − H ( t ∣ s ) = − E t [ log ⁡ p ( t ) ] + E t , s [ log ⁡ p ( t ∣ s ) ] I(\bm{t};\bm{s})=H(\bm{t})-H(\bm{t}|\bm{s})\\ =-\mathbb{E}_{\bm{t}}[\log p(\bm{t})]+\mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})] I(t;s)=H(t)H(ts)=Et[logp(t)]+Et,s[logp(ts)]
之后可以设计如下的loss函数来增大teacher和student之间的输出特征的互信息:
L = L S − ∑ k = 1 K λ k I ( t ( k ) , s ( k ) ) \mathcal{L}=\mathcal{L_{S}}-\sum_{k=1}^{K}\lambda_{k}I(\bm{t}^{(k)},\bm{s}^{(k)}) L=LSk=1KλkI(t(k),s(k))
其中 L S \mathcal{L_{S}} LS表示task-specific的误差, λ k \lambda_{k} λk是超参数用于平衡误差。因为精确的计算MI是困难的,这里采用了变分下界(variational lower bound)的trick,采用variational的思想使用一个variational分布 q ( t ∣ s ) q(\bm{t}|\bm{s}) q(ts)去近似真实分布 p ( t ∣ s ) p(\bm{t}|\bm{s}) p(ts)
Note that variational的思想就是针对某个分布很难求解的时候,采用另外一个分布来近似这个分布的做法,并使用变分信息最大化 (论文:The IM algorithm: A variational approach to information maximization) 的方法求解变分下界(variational low bound),这方法也被用在InfoGAN中。
I ( t ; s ) = H ( t ) − H ( t ∣ s ) = H ( t ) + E t , s [ log ⁡ p ( t ∣ s ) ] = H ( t ) + E t , s [ log ⁡ q ( t ∣ s ) ] + E s [ D K L ( p ( t ∣ s ) ∣ ∣ q ( t ∣ s ) ) ] ≥ H ( t ) + E t , s [ log ⁡ q ( t ∣ s ) ] I(\bm{t};\bm{s})=H(\bm{t})-H(\bm{t}|\bm{s})\\ =H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})]\\ =H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]+\mathbb{E}_{\bm{s}}[D_{KL}(p(\bm{t|s})||q(\bm{t|s}))]\\ \geq H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})] I(t;s)=H(t)H(ts)=H(t)+Et,s[logp(ts)]=H(t)+Et,s[logq(ts)]+Es[DKL(p(ts)q(ts))]H(t)+Et,s[logq(ts)]
E t , s [ log ⁡ p ( t ∣ s ) ] = E t , s [ log ⁡ q ( t ∣ s ) ] + E s [ D K L ( p ( t ∣ s ) ∣ ∣ q ( t ∣ s ) ) ] \mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})]=\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]+\mathbb{E}_{\bm{s}}[D_{KL}(p(\bm{t|s})||q(\bm{t|s}))] Et,s[logp(ts)]=Et,s[logq(ts)]+Es[DKL(p(ts)q(ts))]这个关系是由变分信息最大化中得到的,真实分布 log ⁡ p ( t ∣ s ) \log p(\bm{t|s}) logp(ts)的期望等于变分分布 E t , s [ log ⁡ q ( t ∣ s ) ] \mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})] Et,s[logq(ts)]的期望+两分布的KL散度期望。因为KL散度的值是恒大于0的,所以得到变分下界。进一步可以得到如下的误差函数:
L ~ = L S − ∑ k = 1 K λ k E t ( k ) , s ( k ) [ log ⁡ q ( t ( k ) ∣ s ( k ) ) ] \mathcal{\tilde{L}}=\mathcal{L_{S}}-\sum_{k=1}^{K}\lambda_{k}\mathbb{E}_{\bm{t^{(k)},s^{(k)}}}[\log q(\bm{t^{(k)}|s^{(k)}})] L~=LSk=1KλkEt(k),s(k)[logq(t(k)s(k))]
H ( t ) H(\bm{t}) H(t)由于和待优化的student参数无关,所以是常数。联合的训练学生网络利用target task和最大化条件似然去拟合teacher激活值。

作者采用高斯分布来实例化变分分布,这里的采用heteroscedastic的均值 μ ( ⋅ ) \bm{\mu}(\cdot) μ(),即 μ ( ⋅ ) \bm{\mu}(\cdot) μ()是关于student输出的函数;同时采用homoscedastic的方差 σ \bm{\sigma} σ,即不是关于student输出的函数,作者尝试采用heteroscedastic的均值 σ ( ⋅ ) \bm{\sigma}(\cdot) σ(),但是容易训练不稳定且提升不大。 μ ( ⋅ ) \bm{\mu}(\cdot) μ()其实就是相当于在feature KD时teacher与student之间的回归器,包含卷积等操作。
− log ⁡ q ( t ∣ s ) = − ∑ c = 1 C ∑ h = 1 H ∑ w = 1 W log ⁡ q ( t c , h , w ∣ s ) = ∑ c = 1 C ∑ h = 1 H ∑ w = 1 W log ⁡ σ c + ( t c , h , w − μ c , h , w ( s ) ) 2 2 σ c 2 + c o n s t a n t -\log q(\bm{t|s})=-\sum_{c=1}^{C}\sum_{h=1}^{H}\sum_{w=1}^{W}\log q(t_{c,h,w}|\bm{s})\\ =\sum_{c=1}^{C}\sum_{h=1}^{H}\sum_{w=1}^{W}\log \sigma_{c}+\frac{(t_{c,h,w}-\mu_{c,h,w}(\bm{s}))^{2}}{2\sigma_{c}^{2}}+\rm{constant} logq(ts)=c=1Ch=1Hw=1Wlogq(tc,h,ws)=c=1Ch=1Hw=1Wlogσc+2σc2(tc,h,wμc,h,w(s))2+constant
σ c = log ⁡ ( 1 + e x p ( α c ) ) \sigma_{c}=\log(1+exp(\alpha_{c})) σc=log(1+exp(αc)) α c \alpha_{c} αc是一个可学习的参数。
对于logit层, − log ⁡ q ( t ∣ s ) = − ∑ n = 1 N log ⁡ q ( t n ∣ s ) = ∑ n = 1 N log ⁡ σ n + ( t n − μ n ( s ) ) 2 2 σ n 2 + c o n s t a n t -\log q(\bm{t|s})=-\sum_{n=1}^{N}\log q(t_{n}|\bm{s})\\ =\sum_{n=1}^{N}\log \sigma_{n}+\frac{(t_{n}-\mu_{n}(\bm{s}))^{2}}{2\sigma_{n}^{2}}+\rm{constant} logq(ts)=n=1Nlogq(tns)=n=1Nlogσn+2σn2(tnμn(s))2+constant
这里 μ ( ⋅ ) \bm{\mu}(\cdot) μ()是一个线性的变换矩阵。

与MSE的区别

作者认为当前基于MSE的方法是该方法在方差相同时的特例,即为:
− log ⁡ q ( t ∣ s ) = ∑ n = 1 N ( t n − μ n ( s ) ) 2 2 + c o n s t a n t -\log q(\bm{t|s})=\sum_{n=1}^{N}\frac{(t_{n}-\mu_{n}(\bm{s}))^{2}}{2}+\rm{constant} logq(ts)=n=1N2(tnμn(s))2+constant
VID比MSE的好处为建模了不同维度的方差,使得更加灵活的方式来避免一些model capacity用来到一些无用的信息。MSE采用一样的方差会高度限制student,如果teacher的无用信息也同样的地位拟合,会造成过拟合问题,浪费掉了student的网络capacity。

这篇关于变分互信息蒸馏(Variational mutual information KD)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/663605

相关文章

时序预测|变分模态分解-双向时域卷积-双向门控单元-注意力机制多变量时间序列预测VMD-BiTCN-BiGRU-Attention

时序预测|变分模态分解-双向时域卷积-双向门控单元-注意力机制多变量时间序列预测VMD-BiTCN-BiGRU-Attention 文章目录 一、基本原理1. 变分模态分解(VMD)2. 双向时域卷积(BiTCN)3. 双向门控单元(BiGRU)4. 注意力机制(Attention)总结流程 二、实验结果三、核心代码四、代码获取五、总结 时序预测|变分模态分解-双向时域卷积

计算机视觉中,什么是上下文信息(contextual information)?

在计算机视觉中,上下文信息(contextual information)是指一个像素或一个小区域周围的环境或背景信息,它帮助模型理解图像中对象的相对位置、大小、形状,以及与其他对象的关系。上下文信息在图像中提供了全局的语义和结构线索,使模型不仅依赖局部细节,而且能够考虑整个场景或图像的大局。 上下文信息的具体含义 局部与全局信息的结合: 局部信息:这是指某个小区域或某个像素点的特征。通过小

【读论文】MUTUAL-CHANNEL LOSS

论文题目:《The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification》 链接:https://arxiv.org/abs/2002.04264 来源:IEEE TIP2020 细粒度分类的主要思想是找出各个子类间的可区分特征,因此文章指出要尽早在通道上进行钻研,而不是从合并

ML17_变分推断Variational Inference

1. KL散度 KL散度(Kullback-Leibler divergence),也称为相对熵(relative entropy),是由Solomon Kullback和Richard Leibler在1951年引入的一种衡量两个概率分布之间差异的方法。KL散度不是一种距离度量,因为它不满足距离度量的对称性和三角不等式的要求。但是,它仍然被广泛用于量化两个概率分布之间的“接近程度”。 在

蒸馏之道:如何提取白酒中的精华?

在白酒的酿造过程中,蒸馏是一道至关重要的工序,它如同一位技艺精细的炼金术士,将原料中的精华提炼出来,凝聚成滴滴琼浆。今天,我们就来探寻这蒸馏之道,看看豪迈白酒(HOMANLISM)是如何提取白酒中的精华的。 一、蒸馏:白酒酿造的魔法时刻 蒸馏,是白酒酿造中的关键环节。在这个过程中,酿酒师们通过巧妙的操作和精细的技艺,将原料中的酒精和风味物质提取出来,为后续的陈酿和勾调提供基础。蒸馏不仅要求

平均场变分推断:以混合高斯模型为例

文章目录 一、贝叶斯推断的工作流二、一个业务例子三、变分推断四、平均场理论五、业务CASE的平均场变分推断求解六、代码实现 一、贝叶斯推断的工作流 在贝叶斯推断方法中,工作流可以总结为: 根据观察者的知识,做出合理假设,假设数据是如何被生成的将数据的生成模型转化为数学模型根据数据通过数学方法,求解模型参数对新的数据做出预测 在整个pipeline中,第1点数据的生成过程

深度学习-生成模型:Generation(Tranform Vector To Object with RNN)【PixelRNN、VAE(变分自编码器)、GAN(生成对抗网络)】

深度学习-生成模型:Generation(Tranform Vector To Object with RNN)【PixelRNN、VAE(变分自编码器)、GAN(生成对抗网络)】 一、Generator的分类二、Native Generator (AutoEncoder's Decoder)三、PixelRNN1、生成句子序列2、生成图片3、生成音频:WaveNet4、生成视频:Video

Hive 2.3.0 MetaException(message:Version information not found in metastore. )

使用Hive 2.3.0 配置远程模式(Remote)时,执行hive --service metastore命令时出现MetaException(message:Version information not found in metastore. )错误。 解决办法: The necessary tables required for the metastore are missing i

AWS S3对象无法下载——This XML file does not appear to have any style information associated with it

最近,需要从AWS S3上下载渲染后的图片,遇到了如下问题: This XML file does not appear to have any style information associated with it. The document tree is shown below. <Error><Code>AccessDenied</Code><Message>Acce

大语言模型数据增强与模型蒸馏解决方案

背景 在人工智能和自然语言处理领域,大语言模型通过训练数百亿甚至上千亿参数,实现了出色的文本生成、翻译、总结等任务。然而,这些模型的训练和推理过程需要大量的计算资源,使得它们的实际开发应用成本非常高;其次,大规模语言模型的高能耗和长响应时间问题也限制了其在资源有限场景中的使用。模型蒸馏将大模型“知识”迁移到较小模型。通过模型蒸馏,可以在保留大部分性能的前提下,显著减少模型的规模,从而降低计算资源