本文主要是介绍GumbleSoftmax感性理解--可导式输出随机类别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
GumbleSoftmax
本文不涉及GumbleSoftmax的具体证明和推导,有需要请参见1,只是从感性角度来直观讲解为何要引入GumbleSoftmax,同时又为什么不用Gumblemax。
GumbleSoftmax提出是为了应对分布采样不可导的问题。举例而言,我们从网络经Softmax层输出了类别概率向量 p 1 = [ 0.9 , 0.1 , 0.1 ] p_1=[0.9,0.1,0.1] p1=[0.9,0.1,0.1]和 p 2 = [ 0.5 , 0.2 , 0.3 ] p_2=[0.5,0.2,0.3] p2=[0.5,0.2,0.3],那么如果我们训练网络最终的输出需求只是从中得到对应的类别结果(分类任务),那么 p 1 p_1 p1和 p 2 p_2 p2其实都是合理的,因为我们我们最终得到的都只会是 a r g m a x ( p ) = 0 argmax(p)=0 argmax(p)=0。但如果我们正在进行生成任务,这一类别结果只是一个中间值,而我们希望这一类别概率向量真正体现出了概率的含义,那么 p 1 , p 2 p_1,p_2 p1,p2就会有着显著的差异,后者采样出第1、2类的的结果要明显高于前者。
因此为了突出网络输出的概率属性,我们可以简单的依照这一概率向量进行采样即可,定一个均匀分布 U ( 0 , 1 ) U(0,1) U(0,1),落在哪个概率区间就认为输出哪一个类别,但这一采样操作是不可导的,也就无法使网络端到端训练。GumbleSoftmax的提出就是为了解决这一问题,它让网络输出类别随机的同时,又使得这一采样过程可导。一句话总结:GumbleSoftmaxd代替了网络中的 a r g m a x argmax argmax,引入了:
- 随机性:网络的输出真的变成了由最终概率向量决定的随机变量,即logit输出 [ 0.9 , 0.1 , 0.1 ] [0.9,0.1,0.1] [0.9,0.1,0.1]真的可能因抽样而判定为第2类;
- 可导性:这一抽样过程可导,可以融入到网络端到端训练过程中。(伪)
GumbleMax
为了让网络的输出类别真正的随机,我们需要先将对 a r g m a x argmax argmax进行替换,既然网络输出随机的就不可导的话,我们就利用重参数技巧将这一随机性放到另一个随机变量上,也就得到了GumbleMax,公式如下:
x = a r g m a x ( l o g ( x ) + G ) , \bold{x}=argmax(log(\bold{x})+\bold{G}), x=argmax(log(x)+G),
其中 x , G \bold{x},\bold{G} x,G分别是网络输出的概率向量、符合Gumble分布的噪声向量, G i = − l o g ( − l o g ( U i ) ) , U i U ( 0 , 1 ) G_i=-log(-log(U_i)),U_i~U(0,1) Gi=−log(−log(Ui)),Ui U(0,1)。这一噪声向量的引入就会使得argmax的输出结果发生扰动,变成一个随机变量。同样是之前的例子, l o g ( p 1 ) + G log(p_1)+\bold{G} log(p1)+G就有可能变为 [ 0.5 , 0.6 , 0.5 ] [0.5,0.6,0.5] [0.5,0.6,0.5]而使得最终输出类别为第1类,而 a r g m a x ( l o g ( x ) + G ) argmax(log(\bold{x})+\bold{G}) argmax(log(x)+G)服从这一随机变量服从 x x x的离散分布列证明见附1。
通过引入GumbleMax,我们成功的为网络的类别输出引入了随机性。但可导性的问题并没有解决,因为这里仍然是存在了argmax。
GumbleSoftMax
GumbleSoftMax对GumbleMax的解决也很简单,它又把argmax替换成为了softmax,得到如下计算:
x = s o f t m a x ( ( l o g ( x ) + G ) / τ ) , \bold{x}=softmax((log(\bold{x})+\bold{G})/\tau), x=softmax((log(x)+G)/τ),
其中 τ \tau τ为为温度参数,这一算式中通过对argmax的软化实现了可导操作。至此,也就完成了为了网络输出引入可导随机性的目标。
矛盾
讨论至此,有个非常反直觉的考量,那就是相比于GumbleMax的硬输出onehot向量,GumbleSoftMax的输出似乎又变成了概率向量,我们想要得到的具体的类别输出,还要继续再取argmax也就是 a r g m a x ( s o f t m a x ( ( l o g ( x ) + G ) ) / τ ) argmax(softmax((log(\bold{x})+\bold{G}))/\tau) argmax(softmax((log(x)+G))/τ)。那么这不是仍然不可导,仍然返回了GumbleMax的窘境?因此这里依据个人理解要做出以下的澄清:
- 确实不可导,如果我们希望从GumbleSoftMax输出一个类别值,那么就必然引入argmax,也就必然不可导。而在实际过程中,我们则是回避了对argmax求导的问题,直接对 s o f t m a x ( ( l o g ( x ) + G ) ) / τ softmax((log(\bold{x})+\bold{G}))/\tau softmax((log(x)+G))/τ进行求导,具体可以参见pytorch中Gumblesoftmax的实现2。
- 既然如此,那为什么不照猫画虎在使用Gumblemax的时候就忽略argmax的存在,直接对 ( l o g ( x ) + G ) (log(\bold{x})+\bold{G}) (log(x)+G)求导?这是因为 a r g m a x ( l o g ( x ) + G ) argmax(log(\bold{x})+\bold{G}) argmax(log(x)+G)本身才是我们想要求导的对象,而因为argmax本身不可导,所以引入了softmax来替代,也即我们相对 [ 1 , 0 , 0 ] [1,0,0] [1,0,0]求导,迫不得已对 [ 0.8 , 0.1 , 0.1 ] [0.8,0.1,0.1] [0.8,0.1,0.1]求导,算是某种程度上的导数近似。而在1中的argmax本身也不是我们求导的对象,只是由于这一近似带来的补偿。而更进一步的,假设我们直接对 ( l o g ( x ) + G ) (log(\bold{x})+\bold{G}) (log(x)+G)进行求导,那么这一近似带来的误差只会更大,也让随机噪声的引入失去了意义,等价于对 l o g ( x ) log(x) log(x)求导。这也就是为什么开头的可导加了伪,因为我们是在对softmax求导,而不是argmax。
总结
整体而言,GumbleSoftmax通过引入了Gumble随机噪声使得输出的类别真正具有随机性,而将argmax软化为softmax则使得这一随机过程可导。
参考文献
Gumbel-Softmax Trick和Gumbel分布 ↩︎ ↩︎
请问用Gumbel-softmax的时候,怎么让softmax输出的概率分布转化成one-hot向量? ↩︎
这篇关于GumbleSoftmax感性理解--可导式输出随机类别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!