本文主要是介绍Gumbel Softmax,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
Argmax是不可求导的,Gumbel Softmax允许模型能从网络层的离散分布(比如类别分布categorical distribution)中稀疏采样的这个过程变得可微,从而允许反向传播时可以用梯度更新模型参数。
算法流程
- 对于某个网络层输出的 n \mathrm{n} n 维向量 v = [ v 1 , v 2 , … , v n ] v=\left[v_1, v_2, \ldots, v_n\right] v=[v1,v2,…,vn],生成 n \mathrm{n} n 个服从均匀分布 U ( 0 , 1 ) \mathrm{U}(0,1) U(0,1) 的独立样本 ϵ 1 , … , ϵ n \epsilon_1, \ldots, \epsilon_n ϵ1,…,ϵn
- 通过 G i = − log ( − log ( ϵ i ) ) G_i=-\log \left(-\log \left(\epsilon_i\right)\right) Gi=−log(−log(ϵi)) 计算得到 G i G_i Gi
- 对应相加得到新的值向量 v ′ = [ v 1 + G 1 , v 2 + G 2 , … , v n + G n ] v^{\prime}=\left[v_1+G_1, v_2+G_2, \ldots, v_n+G_n\right] v′=[v1+G1,v2+G2,…,vn+Gn]
- 通过softmax函数计算各个类别的概率大小,其中 τ \tau τ 是温度参数:
p τ ( v i ′ ) = e v i ′ / r ∑ j = 1 n e v j ′ / τ p_\tau\left(v_i^{\prime}\right)=\frac{e^{v_i^{\prime} / r}}{\sum_{j=1}^n e^{v_j^{\prime} / \tau}} pτ(vi′)=∑j=1nevj′/τevi′/r
Gumbel-Max Trick
Gumbel分布是专门用来建模从其他分布(比如高斯分布)采样出来的极值形成的分布,而我们这里“使用argmax挑出概率最大的那个类别索引”就属于取极值的操作,所以它属于Gumbel分布。
注意,极值的分布也是有规律的。
Gumbel-Max Trick的采样思想:先用均匀分布采样出一个随机值,然后把这个值带入到gumbel分布的CDF函数的逆函数得到采样值,即我们最终想要的类别索引。公示如下:
z = argmax i ( log ( p i ) + g i ) g i = − log ( − log ( u i ) ) , u i ∼ U ( 0 , 1 ) z=\operatorname{argmax}_i\left(\log \left(p_i\right)+g_i\right) \\ g_i=-\log \left(-\log \left(u_i\right)\right), u_i \sim U(0,1) z=argmaxi(log(pi)+gi)gi=−log(−log(ui)),ui∼U(0,1)
上式使用了重参数技巧把采样过程分成了确定性的部分和随机性的部分,我们会计算所有类别的log分布概率(确定性的部分),然后加上一些噪音(随机性的部分),这里噪音是标准gumbel分布。在我们把采样过程的确定性部分和随机性部分结合起来之后,我们在此基础上再用一个argmax来找到具有最大概率的类别。
Softmax
使用softmax替换不可导的argmax,用温度系数 τ \tau τ 来近似argmax:
p i ′ = exp ( g i + log p i τ ) ∑ j exp ( g j + log p j τ ) p_i^{\prime}=\frac{\exp \left(\frac{g_i+\log p_i}{\tau}\right)}{\sum_j \exp \left(\frac{g_j+\log p_j}{\tau}\right)} pi′=∑jexp(τgj+logpj)exp(τgi+logpi)
τ \tau τ 越大,越接近argmax。
参考
- CATEGORICAL REPARAMETERIZATION
WITH GUMBEL-SOFTMAX - 通俗易懂地理解Gumbel Softmax
这篇关于Gumbel Softmax的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!