【Deep Learning】Variational Autoencoder ELBO: An Elegant Mathematical Derivation

2024-04-13 00:52


Variational Autoencoder

  • In this note, we discuss the generative model, where $x$ denotes a data point from the given dataset $D$, $z$ denotes the latent variable, and $\theta,\phi$ denote the model parameters.

Latent Variable Model

  • Generate $x$ from the latent variable $z$: $p(x,z)=p(z)\,p(x|z)$
  • Training: Maximum likelihood

$$
\begin{aligned}
L(\theta)&=\sum_{x\in D}\log p(x)\\
&=\sum_{x\in D}\log \sum_{z}p(x,z;\theta)\\
&=\sum_{x\in D}\log \sum_{z} q(z)\frac{p(x,z;\theta)}{q(z)} && \text{importance sampling}\\
&\ge\sum_{x\in D}\sum_{z}q(z)\log \frac{p(x,z;\theta)}{q(z)} && \text{concavity of }\log\text{ (Jensen's inequality)}
\end{aligned}
$$

  • Assumption: $q(z)$ is a distribution over $z$, i.e. $\sum_z q(z)=1$. The sums over $z$ can equivalently be written as expectations under $q$; we use sums for simplicity. (A quick numerical check of the bound follows below.)
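
The bound above is easy to sanity-check numerically. Below is a minimal sketch using a toy discrete latent-variable model (all sizes and variable names here are made up for illustration): for any distribution $q(z)$ the ELBO stays below $\log p(x)$, and it touches it when $q$ equals the posterior.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy discrete latent-variable model: z takes K values, x takes M values.
K, M = 5, 4
p_z = rng.dirichlet(np.ones(K))              # prior p(z)
p_x_given_z = rng.dirichlet(np.ones(M), K)   # likelihood p(x|z), one row per z

x = 2                                        # one observed data point
p_xz = p_z * p_x_given_z[:, x]               # joint p(x, z) as a function of z
log_px = np.log(p_xz.sum())                  # exact log-evidence log p(x)

def elbo(q):
    """ELBO for a variational distribution q(z): sum_z q(z) log[p(x,z)/q(z)]."""
    return np.sum(q * (np.log(p_xz) - np.log(q)))

for _ in range(5):
    q = rng.dirichlet(np.ones(K))            # an arbitrary q(z)
    assert elbo(q) <= log_px + 1e-12         # Jensen: ELBO never exceeds log p(x)

posterior = p_xz / p_xz.sum()                # p(z|x)
print(log_px, elbo(posterior))               # equal: the bound is tight at q = p(z|x)
```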

ELBO

  • In the above derivation, $\sum_z q(z)\log\frac{p(x,z;\theta)}{q(z)}$ is the Evidence Lower Bound (ELBO) of $\log p(x)$.
  • When $q(z)=p(z|x;\theta)$,

$$\sum_z q(z)\log\frac{p(x,z;\theta)}{q(z)}=\sum_z p(z|x;\theta)\log\frac{p(x,z;\theta)}{p(z|x;\theta)}=\sum_z p(z|x;\theta)\log p(x;\theta)=\log p(x;\theta)$$

  • Hence we can set $q(z)=p(z|x;\theta)$ to optimize a tight lower bound of $\log p(x;\theta)$.

    • We call $p(z|x;\theta)$ the posterior.
    • Don't know $p(z|x;\theta)$? Use a network $q(z;\phi)$ to parameterize $p(z|x)$.
    • Optimize $q(z;\phi)\approx p(z|x;\theta)$ and $p(x|z;\theta)$ alternately.
  • Since we use $q(z;\phi)$ to approximate $p(z|x;\theta)$, what distance metric should we use between them?

    • $KL(q\,\|\,p)=\sum_z q(z)\log \frac{q(z)}{p(z)}$
      • Compared to $KL(p\,\|\,q)$, $KL(q\,\|\,p)$ is the reverse KL.
        • Empirically, we often use $KL(p\,\|\,q)$, where $p$ is the ground-truth distribution; that is why $KL(q\,\|\,p)$ is called 'reverse'.
    • We call the procedure of finding such a $\phi$ variational inference: $\min_\phi KL(q\,\|\,p)$.
  • Look at the optimization of $KL(q\,\|\,p)$:
    $$
    \begin{aligned}
    KL\bigl(q(z;\phi)\,\|\,p(z|x)\bigr)&=\sum_{z}q(z;\phi)\log \frac{q(z;\phi)}{p(z|x)}\\
    &=\sum_{z}q(z;\phi)\log \frac{q(z;\phi)\,p(x)}{p(z,x)}\\
    &=\log p(x)-\sum_{z}q(z;\phi)\log \frac{p(z,x)}{q(z;\phi)}
    \end{aligned}
    $$
    Amazing! $\sum_{z}q(z;\phi)\log \frac{p(z,x)}{q(z;\phi)}$ is exactly the ELBO! When we minimize $KL(q\,\|\,p)$, we are also maximizing the ELBO, which means the objectives we alternately optimize for $p(x|z;\theta)$ and $q(z;\phi)$ are, remarkably, the same!

    What’s more, we can also find that
    $$\log p(x)=KL\bigl(q(z;\phi)\,\|\,p(z|x)\bigr)+\mathrm{ELBO}=\text{approximation error}+\mathrm{ELBO}$$
    which verifies that the ELBO is a lower bound of $\log p(x)$, and that their difference is exactly the approximation error between $q(z;\phi)$ and $p(z|x)$. (A numerical check of this identity is sketched below.)
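
A minimal numerical check of this decomposition, in the same spirit as the toy discrete setting above (the setup is again illustrative, not from the original note):

```python
import numpy as np

rng = np.random.default_rng(1)
K = 6
p_xz = 0.1 * rng.random(K)                   # joint p(x, z) values for one fixed x
log_px = np.log(p_xz.sum())                  # log p(x) = log sum_z p(x, z)
posterior = p_xz / p_xz.sum()                # true posterior p(z|x)

q = rng.dirichlet(np.ones(K))                # arbitrary variational q(z; phi)
elbo = np.sum(q * (np.log(p_xz) - np.log(q)))
kl = np.sum(q * (np.log(q) - np.log(posterior)))

assert np.isclose(log_px, kl + elbo)         # log p(x) = KL(q || p(z|x)) + ELBO
```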

  • Notice: since $q(z;\phi)\approx p(z|x;\theta)$, the optimal $q$ depends on $x$; hence we can use $q(z|x;\phi)$ instead of $q(z;\phi)$. This is called Amortized Variational Inference.

  • Now the ELBO is our single joint objective $J$. Train $\theta$ and $\phi$ together!
    $$
    \begin{aligned}
    J(\theta,\phi;x)&=\sum_z q(z|x;\phi)\log\frac{p(x,z;\theta)}{q(z|x;\phi)}\\
    &=\sum_z q(z|x;\phi)\bigl(\log p(x|z;\theta)+\log p(z;\theta)-\log q(z|x;\phi)\bigr)\\
    &=\sum_z q(z|x;\phi)\log p(x|z;\theta)-\sum_z q(z|x;\phi)\log\frac{q(z|x;\phi)}{p(z;\theta)}\\
    &=\mathbb{E}_{z\sim q(\cdot|x;\phi)}\log p(x|z;\theta)-KL\bigl(q(z|x;\phi)\,\|\,p(z;\theta)\bigr)
    \end{aligned}
    $$
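    In practice the expectation over $z$ is estimated with Monte Carlo samples from $q$; with a single sample this gives the standard one-sample estimate (this approximation step is not spelled out in the original note):
    $$J(\theta,\phi;x)\approx \log p(x|z^{(1)};\theta)-KL\bigl(q(z|x;\phi)\,\|\,p(z;\theta)\bigr),\qquad z^{(1)}\sim q(\cdot|x;\phi)$$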

VAE

  • Practically, we obtain the VAE from the ELBO.

  • Assume that
    $$p(z)=N(0,I),\qquad q(z|x;\phi)=N\bigl(\mu_\phi(x),\sigma_\phi(x)\bigr),\qquad p(x|z;\theta)=N\bigl(\mu_\theta(z),\sigma_\theta(z)\bigr)$$
    They are all Gaussian, with means and variances produced by neural networks.
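    With these assumptions (and a diagonal covariance, which is the usual choice), the KL term of the ELBO has a closed form; writing $d$ for the latent dimension:
    $$KL\bigl(q(z|x;\phi)\,\|\,p(z)\bigr)=\frac{1}{2}\sum_{i=1}^{d}\Bigl(\mu_{\phi,i}^{2}(x)+\sigma_{\phi,i}^{2}(x)-\log\sigma_{\phi,i}^{2}(x)-1\Bigr)$$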

  • Let $q(z|x;\phi)$ be the encoder and $p(x|z;\theta)$ be the decoder; then $\mathbb{E}_{z\sim q(\cdot|x;\phi)}\log p(x|z;\theta)$ represents the reconstruction term:

    • It measures the error after encoding $x$ into the latent space and decoding back into the original space.
    • We want this term to be large, so that the original data can be recovered with high probability.
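    For the Gaussian decoder assumed above, this term is explicitly a negative weighted squared error plus a log-variance penalty; with a fixed decoder variance it reduces to the familiar MSE reconstruction loss:
    $$\log p(x|z;\theta)=-\sum_{i}\Bigl(\frac{\bigl(x_i-\mu_{\theta,i}(z)\bigr)^{2}}{2\sigma_{\theta,i}^{2}(z)}+\frac{1}{2}\log 2\pi\sigma_{\theta,i}^{2}(z)\Bigr)$$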
  • Re-parameterization trick:

    • In the $\mathbb{E}_{z\sim q(\cdot|x;\phi)}\log p(x|z;\theta)$ term, the sampling distribution depends on $\phi$, so the gradient with respect to $\phi$ cannot be back-propagated through the sampling step directly.
    • Sample $z'\sim N(0,I)$, then compute $z=\mu+z'\cdot\sigma$. Now the randomness no longer depends on $\phi$, and gradients flow through $\mu$ and $\sigma$.
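
To tie everything together, here is a minimal PyTorch-style sketch of a Gaussian VAE; all module names, sizes, and hyper-parameters are illustrative choices, not from the original note. It wires up the encoder $q(z|x;\phi)$, the re-parameterized sample, the decoder $p(x|z;\theta)$ (with unit variance for simplicity), and the negative ELBO as the training loss.

```python
import torch
import torch.nn as nn

class GaussianVAE(nn.Module):
    """Minimal VAE with a Gaussian encoder and a unit-variance Gaussian decoder."""
    def __init__(self, x_dim=784, z_dim=16, hidden=256):
        super().__init__()
        # Encoder q(z|x; phi): outputs mean and log-variance of z
        self.enc = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.enc_mu = nn.Linear(hidden, z_dim)
        self.enc_logvar = nn.Linear(hidden, z_dim)
        # Decoder p(x|z; theta): outputs the mean of x (variance fixed to 1 here)
        self.dec = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, x_dim))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        # Re-parameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        eps = torch.randn_like(mu)
        z = mu + torch.exp(0.5 * logvar) * eps
        x_hat = self.dec(z)

        # Reconstruction term E_q[log p(x|z)] up to constants: -0.5 * ||x - x_hat||^2
        recon = -0.5 * ((x - x_hat) ** 2).sum(dim=1)
        # KL(q(z|x) || N(0, I)) in closed form (see the formula above)
        kl = 0.5 * (mu ** 2 + logvar.exp() - logvar - 1).sum(dim=1)
        return (kl - recon).mean()           # negative ELBO, to be minimized

# Usage sketch (random data in place of a real dataset)
model = GaussianVAE()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.rand(32, 784)
loss = model(x)
opt.zero_grad()
loss.backward()
opt.step()
```

Using the closed-form KL avoids a second Monte Carlo estimate, and the re-parameterized sample lets the reconstruction term back-propagate into $\mu_\phi$ and $\sigma_\phi$.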

Conclusion

  • The elegant mathematical derivation behind the VAE inspired me to write this blog post.
  • Furthermore, compared to GANs, the VAE shows great training stability across many tasks. There are more pros and cons worth discussing.
