This article works through the detailed derivation of diffusion models: the training objective and the sampling procedure.
Training and Sampling Algorithms for Diffusion Models
Deriving the Training Objective
We want the samples $\boldsymbol{x}^{(i)}$ produced by the denoising process to be as probable as possible. Ignoring the sample index for now and dropping the superscript, this means maximizing $p(\boldsymbol{x}|\theta_{1:T})$, or equivalently maximizing $\log \left[p(\boldsymbol{x}|\theta_{1:T})\right]$. This quantity cannot be maximized directly, so we look for an evidence lower bound (ELBO):
$$\begin{align}\log \left[p(\boldsymbol{x}|\theta_{1:T})\right]&=\log\left[\int p(\boldsymbol{x}, \boldsymbol{z}_{1:T}|\theta_{1:T})d\boldsymbol{z}_{1:T}\right]\\&=\log\left[\int q(\boldsymbol{z}_{1:T}|\boldsymbol{x})\frac{p(\boldsymbol{x}, \boldsymbol{z}_{1:T}|\theta_{1:T})}{q(\boldsymbol{z}_{1:T}|\boldsymbol{x})}d\boldsymbol{z}_{1:T}\right]\\ &\ge \int q(\boldsymbol{z}_{1:T}|\boldsymbol{x})\log\left[\frac{p(\boldsymbol{x}, \boldsymbol{z}_{1:T}|\theta_{1:T})}{q(\boldsymbol{z}_{1:T}|\boldsymbol{x})}\right]d\boldsymbol{z}_{1:T}\end{align}$$
The last step is Jensen's inequality: since $\log$ is concave, the log of an expectation is at least the expectation of the log.
Expanding the log ratio inside the integral,
$$\begin{align}\log\left[\frac{p(\boldsymbol{x},\boldsymbol{z}_{1:T}|\theta_{1:T})}{q(\boldsymbol{z}_{1:T}|\boldsymbol{x})}\right]&=\log\left[\frac{p(\boldsymbol{x}|\boldsymbol{z}_1,\theta_1)\prod_{t=2}^Tp(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\theta_t)\cdot p(\boldsymbol{z}_T)}{q(\boldsymbol{z}_1|\boldsymbol{x})\prod_{t=2}^Tq(\boldsymbol{z}_t|\boldsymbol{z}_{t-1})}\right]\\&=\log\left[\frac{p(\boldsymbol{x}|\boldsymbol{z}_1,\theta_1)}{q(\boldsymbol{z}_1|\boldsymbol{x})}\right]+\log\left[\frac{\prod_{t=2}^Tp(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\theta_t)}{\prod_{t=2}^Tq(\boldsymbol{z}_t|\boldsymbol{z}_{t-1})}\right]+\log\Bigl[p(\boldsymbol{z}_T)\Bigr]\end{align}$$
By the Markov property of the forward diffusion process, together with Bayes' rule,
$$\begin{equation}q(\boldsymbol{z}_t|\boldsymbol{z}_{t-1})=q(\boldsymbol{z}_t|\boldsymbol{z}_{t-1},\boldsymbol{x})=\frac{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})q(\boldsymbol{z}_t|\boldsymbol{x})}{q(\boldsymbol{z}_{t-1}|\boldsymbol{x})}\end{equation}$$
so the expansion above can be simplified further:
$$\begin{align}\log\left[\frac{p(\boldsymbol{x},\boldsymbol{z}_{1:T}|\theta_{1:T})}{q(\boldsymbol{z}_{1:T}|\boldsymbol{x})}\right]&=\log\left[\frac{p(\boldsymbol{x}|\boldsymbol{z}_1,\theta_1)}{q(\boldsymbol{z}_1|\boldsymbol{x})}\right]+\log\left[\frac{\prod_{t=2}^Tp(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\theta_t)\cdot q(\boldsymbol{z}_{1}|\boldsymbol{x})}{\prod_{t=2}^Tp(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\boldsymbol{x})\cdot q(\boldsymbol{z}_T|\boldsymbol{x})}\right]+\log\Bigl[p(\boldsymbol{z}_T)\Bigr]\\&=\log\left[p(\boldsymbol{x}|\boldsymbol{z}_1,\theta_1)\right]+\log\left[\frac{\prod_{t=2}^Tp(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\theta_t)}{\prod_{t=2}^Tp(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\boldsymbol{x})}\right]+\log\left[\frac{p(\boldsymbol{z}_T)}{q(\boldsymbol{z}_T|\boldsymbol{x})}\right]\\&\approx\log\left[p(\boldsymbol{x}|\boldsymbol{z}_1,\theta_1)\right]+\sum_{t=2}^T\log\left[\frac{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\theta_t)}{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\boldsymbol{x})}\right]\end{align}$$
In the first line, the factors $q(\boldsymbol{z}_t|\boldsymbol{x})/q(\boldsymbol{z}_{t-1}|\boldsymbol{x})$ telescope across the product, leaving only $q(\boldsymbol{z}_1|\boldsymbol{x})$ (which cancels against the first term) and $q(\boldsymbol{z}_T|\boldsymbol{x})$. From the second line to the third: $p(\boldsymbol{z}_T)$ is a standard Gaussian and $q(\boldsymbol{z}_T|\boldsymbol{x})$ is approximately a standard Gaussian, so $\log\left[\frac{p(\boldsymbol{z}_T)}{q(\boldsymbol{z}_T|\boldsymbol{x})}\right]\approx\log 1=0$.
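To see why $q(\boldsymbol{z}_T|\boldsymbol{x})$ is approximately standard Gaussian, recall the closed form of the forward process used later in this article, $\boldsymbol{z}_t=\sqrt{\alpha_t}\boldsymbol{x}+\sqrt{1-\alpha_t}\boldsymbol{\epsilon}_t$: for a sufficiently long noise schedule, $\alpha_T=\prod_{s=1}^T(1-\beta_s)$ is close to zero, so

$$q(\boldsymbol{z}_T|\boldsymbol{x})=N\left(\sqrt{\alpha_T}\,\boldsymbol{x},\;(1-\alpha_T)\boldsymbol{I}\right)\approx N(\boldsymbol{0},\boldsymbol{I})$$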
Putting everything together, the full bound can now be written out:
$$\begin{align}\log \left[p(\boldsymbol{x}|\theta_{1:T})\right] &\ge \log\left[\int q(\boldsymbol{z}_{1:T}|\boldsymbol{x})\frac{p(\boldsymbol{x}, \boldsymbol{z}_{1:T}|\theta_{1:T})}{q(\boldsymbol{z}_{1:T}|\boldsymbol{x})}d\boldsymbol{z}_{1:T}\right]\\ &\approx\int q(\boldsymbol{z}_{1:T}|\boldsymbol{x})\left(\log\left[p(\boldsymbol{x}|\boldsymbol{z}_{1},\theta_{1})\right]+\sum_{t=2}^{T}\log\left[\frac{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_{t})}{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})}\right]\right)d\boldsymbol{z}_{1:T}\\&=\int q(\boldsymbol{z}_{1:T}|\boldsymbol{x})\log\left[p(\boldsymbol{x}|\boldsymbol{z}_{1},\theta_{1})\right]d\boldsymbol{z}_{1:T}+\int q(\boldsymbol{z}_{1:T}|\boldsymbol{x})\sum_{t=2}^{T}\log\left[\frac{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_{t})}{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})}\right]d\boldsymbol{z}_{1:T}\\&=\int q(\boldsymbol{z}_{1}|\boldsymbol{x})\log\left[p(\boldsymbol{x}|\boldsymbol{z}_{1},\theta_{1})\right]d\boldsymbol{z}_{1}+\sum_{t=2}^{T}\int q(\boldsymbol{z}_{1:T}|\boldsymbol{x})\log\left[\frac{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_{t})}{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})}\right]d\boldsymbol{z}_{1:T}\\ &= E_{q(\boldsymbol{z}_{1}|\boldsymbol{x})}\left[\log\left[p(\boldsymbol{x}|\boldsymbol{z}_{1},\theta_{1})\right]\right]+\sum_{t=2}^{T}\iint q(\boldsymbol{z}_{t-1},\boldsymbol{z}_{t}|\boldsymbol{x})\log\left[\frac{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_{t})}{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})}\right]d\boldsymbol{z}_{t-1}d\boldsymbol{z}_{t}\\ &= E_{q(\boldsymbol{z}_{1}|\boldsymbol{x})}\left[\log\left[p(\boldsymbol{x}|\boldsymbol{z}_{1},\theta_{1})\right]\right]+\sum_{t=2}^{T}\iint q(\boldsymbol{z}_{t}|\boldsymbol{x})p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})\log\left[\frac{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_{t})}{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})}\right]d\boldsymbol{z}_{t-1}d\boldsymbol{z}_{t}\\ &= E_{q(\boldsymbol{z}_{1}|\boldsymbol{x})}\left[\log\left[p(\boldsymbol{x}|\boldsymbol{z}_{1},\theta_{1})\right]\right]+\sum_{t=2}^{T}\int q(\boldsymbol{z}_{t}|\boldsymbol{x})\left(\int p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})\log\left[\frac{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_{t})}{p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})}\right]d\boldsymbol{z}_{t-1}\right)d\boldsymbol{z}_{t}\\ &= E_{q(\boldsymbol{z}_{1}|\boldsymbol{x})}\left[\log\left[p(\boldsymbol{x}|\boldsymbol{z}_{1},\theta_{1})\right]\right]-\sum_{t=2}^{T}\int q(\boldsymbol{z}_{t}|\boldsymbol{x})\cdot D_{KL}\left[p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\boldsymbol{x})||p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t, \theta_t)\right]d\boldsymbol{z}_{t}\\&= E_{q(\boldsymbol{z}_{1}|\boldsymbol{x})}\left[\log\left[p(\boldsymbol{x}|\boldsymbol{z}_{1},\theta_{1})\right]\right]-\sum_{t=2}^{T}E_{q(\boldsymbol{z}_{t}|\boldsymbol{x})}\left[ D_{KL}\left[p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\boldsymbol{x})||p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t, \theta_t)\right]\right]\\ &= E_{q(\boldsymbol{z}_1|\boldsymbol{x})}\left[\log\left[N(f_1(\boldsymbol{z}_1,\theta_1),\sigma_1^2\boldsymbol{I})\right]\right]-\sum_{t=2}^{T}E_{q(\boldsymbol{z}_{t}|\boldsymbol{x})}\left[\frac{1}{2\sigma_t^2}\left\|\frac{(1-\alpha_{t-1})}{1-\alpha_t}\sqrt{1-\beta_t}\boldsymbol{z}_t+\frac{\sqrt{\alpha_{t-1}}\beta_t}{1-\alpha_t}\boldsymbol{x}-f_t[\boldsymbol{z}_t,\theta_t]\right\|^2+C\right]\\ &\approx \log\left[N(f_1(\boldsymbol{z}_1^*,\theta_1),\sigma_1^2\boldsymbol{I})\right]-\sum_{t=2}^{T}\frac{1}{2\sigma_t^2}\left\|\frac{(1-\alpha_{t-1})}{1-\alpha_t}\sqrt{1-\beta_t}\boldsymbol{z}_t^*+\frac{\sqrt{\alpha_{t-1}}\beta_t}{1-\alpha_t}\boldsymbol{x}-f_t[\boldsymbol{z}_t^*,\theta_t]\right\|^2-C\end{align}$$
where
$$\begin{equation}\log\left[p(\boldsymbol{x}|\boldsymbol{z}_{1},\theta_{1})\right]=\log\left[N(f_1(\boldsymbol{z}_1,\theta_1),\sigma_1^2\boldsymbol{I})\right]\end{equation}$$
$$\begin{equation}D_{KL}\left[p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\boldsymbol{x})||p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t, \theta_t)\right]=\frac{1}{2\sigma_t^2}\left\|\frac{(1-\alpha_{t-1})}{1-\alpha_t}\sqrt{1-\beta_t}\boldsymbol{z}_t+\frac{\sqrt{\alpha_{t-1}}\beta_t}{1-\alpha_t}\boldsymbol{x}-f_t[\boldsymbol{z}_t,\theta_t]\right\|^2+C\end{equation}$$
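This closed form is an instance of the general KL divergence between two Gaussians. When both distributions share the covariance $\sigma^2\boldsymbol{I}$, only the squared distance between the means survives:

$$D_{KL}\left[N(\boldsymbol{\mu}_1,\sigma^2\boldsymbol{I})\,||\,N(\boldsymbol{\mu}_2,\sigma^2\boldsymbol{I})\right]=\frac{1}{2\sigma^2}\left\|\boldsymbol{\mu}_1-\boldsymbol{\mu}_2\right\|^2$$

When the covariances differ, the extra trace and log-determinant terms do not depend on $\theta_t$ and are absorbed into the constant $C$ above.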
Both expectations have been approximated by Monte Carlo estimates: $\boldsymbol{z}_1^*$ and $\boldsymbol{z}_t^*$ are single samples drawn from $q(\boldsymbol{z}_1|\boldsymbol{x})$ and $q(\boldsymbol{z}_t|\boldsymbol{x})$, respectively. The constant $C$ can be ignored during optimization.
The final goal is $\max\prod_{i=1}^np(\boldsymbol{x}^{(i)}|\theta_{1:T})$, which is equivalent to $\max\sum_{i=1}^n\log \left[p(\boldsymbol{x}^{(i)}|\theta_{1:T})\right]$, where $n$ is the total number of training samples.
$$\begin{align}\bar{\theta}_{1:T}&=\arg\min-\sum_{i=1}^n\log \left[p(\boldsymbol{x}^{(i)}|\theta_{1:T})\right]\\&=\arg\min-\sum_{i=1}^n\left[\underbrace{\log\left[N(f_1(\boldsymbol{z}_1^{(i)},\theta_1),\sigma_1^2\boldsymbol{I})\right]}_{\text{① reconstruction loss}}-\sum_{t=2}^{T}\frac{1}{2\sigma_t^2}\left\|\underbrace{\frac{(1-\alpha_{t-1})}{1-\alpha_t}\sqrt{1-\beta_t}\boldsymbol{z}_t^{(i)}+\frac{\sqrt{\alpha_{t-1}}\beta_t}{1-\alpha_t}\boldsymbol{x}^{(i)}}_{\text{② mean of }p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})}-\underbrace{f_t[\boldsymbol{z}_t^{(i)},\theta_t]}_{\text{③ mean estimated by the network}}\right\|^2\right]\end{align}$$
① Reconstruction loss: approximately the log density of $\boldsymbol{x}$ given $\boldsymbol{z}_1$.
② The mean of $p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})$.
③ The mean estimated by the neural network.
This objective is really doing two things: it makes the final $\boldsymbol{x}$ as probable as possible, and it pushes the mean estimated by the neural network during decoding toward the mean of $p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})$.
Since $\boldsymbol{z}_t=\sqrt{\alpha_t}\boldsymbol{x}+\sqrt{1-\alpha_t}\boldsymbol{\epsilon}_t,\ t=1,2,\cdots,T$, we can express $\boldsymbol{x}$ in terms of $\boldsymbol{z}_t$:
$$\begin{equation}\boldsymbol{x}=\frac{1}{\sqrt{\alpha_t}}\boldsymbol{z}_t-\frac{\sqrt{1-\alpha_t}}{\sqrt{\alpha_t}}\boldsymbol{\epsilon}_t\end{equation}$$
Substituting this expression into the objective and using $\frac{\sqrt{\alpha_{t-1}}}{\sqrt{\alpha_t}}=\frac{1}{\sqrt{1-\beta_t}}$ (which follows from $\alpha_t=\prod_{s=1}^t(1-\beta_s)$, so that $\alpha_t=\alpha_{t-1}(1-\beta_t)$), the objective simplifies further to
$$-\sum_{i=1}^n\left[\underbrace{\log\left[N(f_1(\boldsymbol{z}_1^{(i)},\theta_1),\sigma_1^2\boldsymbol{I})\right]}_{\text{① reconstruction loss}}-\sum_{t=2}^{T}\frac{1}{2\sigma_t^2}\left\|\underbrace{\frac{1}{\sqrt{1-\beta_t}}\boldsymbol{z}_t^{(i)}-\frac{\beta_t}{\sqrt{(1-\beta_t)(1-\alpha_t)}}\boldsymbol{\epsilon}_t^{(i)}}_{\text{② mean of }p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})}-\underbrace{f_t[\boldsymbol{z}_t^{(i)},\theta_t]}_{\text{③ mean predicted by the network}}\right\|^2\right]$$
where ③ is computed from $\boldsymbol{z}_t$ and $g_t[\boldsymbol{z}_t,\theta_t]$:
$$\begin{equation}f_t[\boldsymbol{z}_t,\theta_t]=\frac{1}{\sqrt{1-\beta_t}}\boldsymbol{z}_t-\frac{\beta_t}{\sqrt{(1-\beta_t)(1-\alpha_t)}}g_t[\boldsymbol{z}_t,\theta_t]\end{equation}$$
其中 g t g_t gt是用于估计噪声 ϵ t \boldsymbol{\epsilon}_t ϵt的神经网络,而 z t \boldsymbol{z}_t zt在采样阶段是已知的。参数为了方便统一仍记为 θ t \theta_t θt。将该表达式代入,目标函数变为:
$$\begin{align}\bar{\theta}_{1:T}=\arg\min-\sum_{i=1}^n\left[\underbrace{\log\left[N(f_1(\boldsymbol{z}_1^{(i)},\theta_1),\sigma_1^2\boldsymbol{I})\right]}_{\text{① reconstruction loss}}-\sum_{t=2}^{T}\frac{\beta_t^2}{2\sigma_t^2(1-\beta_t)(1-\alpha_t)}\left\|\underbrace{g_t[\boldsymbol{z}_t^{(i)},\theta_t]}_{\text{② noise predicted by the network}}-\underbrace{\boldsymbol{\epsilon}_t^{(i)}}_{\text{③ noise added to }\boldsymbol{x}^{(i)}\text{ at step }t}\right\|^2\right]\end{align}$$
For term ① above, by the definition of the multivariate Gaussian, $\Sigma=\sigma_1^2\boldsymbol{I}$, $|\Sigma|^{1/2}$ is a constant, and $\Sigma^{-1}=\frac{1}{\sigma_1^2}\boldsymbol{I}$, so it can be written as (writing $d$ for the data dimension, to avoid clashing with the sample count $n$):
$$\begin{align}\log\left[N(f_1(\boldsymbol{z}_1^{(i)},\theta_1),\sigma_1^2\boldsymbol{I})\right]=-\log\left[(2\pi)^{d/2}|\Sigma|^{1/2}\right]-\frac{1}{2\sigma_1^2}\left\|\boldsymbol{x}^{(i)}-f_1(\boldsymbol{z}_1^{(i)},\theta_1)\right\|^2\end{align}$$
And from equations (47) and (48) above (the expression for $\boldsymbol{x}$ in terms of $\boldsymbol{z}_t$ and the definition of $f_t$), we know
$$\begin{align}\frac{1}{2\sigma_1^2}\left\|\boldsymbol{x}^{(i)}-f_1(\boldsymbol{z}_1^{(i)},\theta_1)\right\|^2&=\frac{1}{2\sigma_1^2}\left\|\frac{1}{\sqrt{\alpha_1}}\boldsymbol{z}_1^{(i)}-\frac{\sqrt{1-\alpha_1}}{\sqrt{\alpha_1}}\boldsymbol{\epsilon}_1^{(i)}-\frac{1}{\sqrt{1-\beta_1}}\boldsymbol{z}_1^{(i)}+\frac{\beta_1}{\sqrt{(1-\beta_1)(1-\alpha_1)}}g_1[\boldsymbol{z}_1^{(i)},\theta_1]\right\|^2\\&=\frac{\beta_1^2}{2\sigma_1^2(1-\beta_1)(1-\alpha_1)}\left\|g_1[\boldsymbol{z}_1^{(i)},\theta_1]-\boldsymbol{\epsilon}_1^{(i)}\right\|^2\end{align}$$
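The second equality uses $\alpha_1=1-\beta_1$ (the $t=1$ case of $\alpha_t=\prod_{s=1}^t(1-\beta_s)$): the two $\boldsymbol{z}_1^{(i)}$ terms cancel because $\frac{1}{\sqrt{\alpha_1}}=\frac{1}{\sqrt{1-\beta_1}}$, and the two remaining coefficients coincide:

$$\frac{\sqrt{1-\alpha_1}}{\sqrt{\alpha_1}}=\frac{\sqrt{\beta_1}}{\sqrt{1-\beta_1}}=\frac{\beta_1}{\sqrt{(1-\beta_1)\beta_1}}=\frac{\beta_1}{\sqrt{(1-\beta_1)(1-\alpha_1)}}$$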
The objective can therefore be simplified to
$$\begin{equation}\bar{\theta}_{1:T}=\arg\min \sum_{i=1}^{n}\sum_{t=1}^{T}\frac{\beta_t^2}{2\sigma_t^2(1-\beta_t)(1-\alpha_t)}\left\|g_t[\boldsymbol{z}_t^{(i)},\theta_t]-\boldsymbol{\epsilon}_t^{(i)}\right\|^2 \end{equation}$$
Here $-\log\left[(2\pi)^{d/2}|\Sigma|^{1/2}\right]$ is a constant and has been dropped from the objective. In addition, experiments have found that the per-step coefficients in this objective matter little in practice; they can be set to 1 during training to simplify the objective.
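With those coefficients set to 1, the training objective reduces to a plain sum of squared noise-prediction errors:

$$\bar{\theta}_{1:T}=\arg\min \sum_{i=1}^{n}\sum_{t=1}^{T}\left\|g_t[\boldsymbol{z}_t^{(i)},\theta_t]-\boldsymbol{\epsilon}_t^{(i)}\right\|^2$$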
Training
From the encoder and the training objective, the training algorithm is as follows:
For all observed data $\boldsymbol{x}^{(i)}$, set $\boldsymbol{z}_{0}^{(i)}=\boldsymbol{x}^{(i)},\ i=1,\cdots,n$ and loss = 0, then repeat:
——For $t=1,2,\cdots,T$:
————Compute $\boldsymbol{z}_t^{(i)}$: $\boldsymbol{z}_t^{(i)}=\sqrt{\alpha_t}\boldsymbol{x}^{(i)}+\sqrt{1-\alpha_t}\boldsymbol{\epsilon}_t^{(i)}$, where $\boldsymbol{\epsilon}_t^{(i)}\sim N(\boldsymbol{0},\boldsymbol{I})$.
————Train $g_t[\boldsymbol{z}_t^{(i)},\theta_t]$ and accumulate the loss: loss += $\frac{\beta_t^2}{2\sigma_t^2(1-\beta_t)(1-\alpha_t)}\left\|g_t[\boldsymbol{z}_t^{(i)},\theta_t]-\boldsymbol{\epsilon}_t^{(i)}\right\|^2$
Run backpropagation and gradient descent to update $\theta_{1:T}$. (A code sketch of this loop is given below.)
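Below is a minimal PyTorch sketch of this training loop. Everything named here (the toy `EpsModel`, the linear $\beta$ schedule, the placeholder data) is an assumption for illustration, not part of the derivation. It also adopts two common simplifications discussed in this article: the per-step weights are set to 1, and a single time-conditioned network replaces the per-step networks $g_t$; sampling one random $t$ per example per step is standard practice and approximates the sum over $t$ above.

```python
import torch
import torch.nn as nn

T = 1000
beta = torch.linspace(1e-4, 0.02, T)        # beta_1..beta_T (assumed linear schedule)
alpha = torch.cumprod(1.0 - beta, dim=0)    # alpha_t = prod_{s<=t}(1 - beta_s)

class EpsModel(nn.Module):
    """Toy noise-prediction network g[z_t, t, theta]; t enters as an extra feature."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, dim),
        )
    def forward(self, z, t):
        t_feat = t.float().unsqueeze(-1) / T            # crude time encoding
        return self.net(torch.cat([z, t_feat], dim=-1))

model = EpsModel(dim=2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(256, 2)                                 # placeholder data x^{(i)}

for step in range(1000):
    t = torch.randint(1, T + 1, (x.shape[0],))          # random step t in 1..T
    eps = torch.randn_like(x)                           # epsilon_t ~ N(0, I)
    a = alpha[t - 1].unsqueeze(-1)                      # alpha_t (0-indexed storage)
    z_t = a.sqrt() * x + (1 - a).sqrt() * eps           # z_t = sqrt(alpha_t) x + sqrt(1-alpha_t) eps
    loss = ((model(z_t, t) - eps) ** 2).mean()          # weights set to 1
    opt.zero_grad(); loss.backward(); opt.step()
```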
Sampling
From the decoder, the sampling algorithm is as follows (a matching code sketch appears after the algorithm):
Sample $\boldsymbol{z}_T$ from $N(\boldsymbol{0},\boldsymbol{I})$.
For $t=T,T-1,\cdots,1$:
——Feed $\boldsymbol{z}_t$ into the neural network $g_t$ to obtain $g_t[\boldsymbol{z}_t,\theta_t]$.
——Estimate the mean $\boldsymbol{\mu}$ of $p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_t)$: $\boldsymbol{\mu}=\frac{1}{\sqrt{1-\beta_t}}\boldsymbol{z}_t-\frac{\beta_t}{\sqrt{(1-\beta_t)(1-\alpha_t)}}g_t[\boldsymbol{z}_t,\theta_t]$
——If $t>1$:
————Sample $\boldsymbol{z}_{t-1}$ from $p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_t)$: $\boldsymbol{z}_{t-1}=\boldsymbol{\mu}+\sigma_t\boldsymbol{\epsilon}$, where $\boldsymbol{\epsilon}\sim N(\boldsymbol{0},\boldsymbol{I})$
——Else:
————Output $\boldsymbol{x}$ from $p(\boldsymbol{x}|\boldsymbol{z}_{1},\theta_1)$ by taking its mean: $\boldsymbol{x}=\boldsymbol{\mu}$
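A matching PyTorch sketch, continuing the assumptions and definitions (`model`, `beta`, `alpha`, `T`) from the training sketch above; the choice $\sigma_t^2=\beta_t$ is one common setting, not something the derivation mandates:

```python
@torch.no_grad()
def sample(model, n, dim):
    z = torch.randn(n, dim)                          # z_T ~ N(0, I)
    for t in range(T, 0, -1):                        # t = T, T-1, ..., 1
        b, a = beta[t - 1], alpha[t - 1]             # beta_t, alpha_t
        g = model(z, torch.full((n,), t))            # predicted noise g_t[z_t, theta_t]
        mu = (z - b / (1 - a).sqrt() * g) / (1 - b).sqrt()  # mean of p(z_{t-1}|z_t, theta_t)
        if t > 1:
            z = mu + b.sqrt() * torch.randn_like(z)  # z_{t-1} = mu + sigma_t * eps
        else:
            z = mu                                   # x = mu at the final step
    return z

x_gen = sample(model, n=16, dim=2)
```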
In practice one does not actually train $T$ separate networks $g_t[\boldsymbol{z}_t,\theta_t],\ t=1,2,\cdots,T$, one per time step; instead a single network that additionally receives the time information, $g[\boldsymbol{z}_t,t,\theta]$, is used in their place.
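The time information is usually injected through an embedding of $t$. The sinusoidal embedding below is one common choice (an assumption here; the derivation above does not prescribe any particular encoding):

```python
import math
import torch

def time_embedding(t: torch.Tensor, dim: int = 64) -> torch.Tensor:
    """Sinusoidal embedding of integer steps t: shape (batch,) -> (batch, dim)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)  # (half,)
    angles = t.float().unsqueeze(-1) * freqs                           # (batch, half)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)             # (batch, dim)
```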