DDPM扩散模型公式推理----损失函数

DDPM扩散模型公式推理----损失函数

2.3 损失函数推导

我们用极大似然估计的思想构建损失函数:
L = − log ⁡ p θ ( x 0 ) \mathcal{L}=-\log p_{\theta}\left(x_{0}\right) L=logpθ(x0)
即逆扩散网络的参数 θ \theta θ 使刚开始采样的数据 x 0 x_0 x0 出现的概率最大。

接下来需要对上式进行变形,用到了一些ELBO和VAE的内容。

2.3.1 ELBO

已知极大似然估计 p θ ( x 0 ) p_{\theta}\left(x_{0}\right) pθ(x0) 和观测值 x 0 x_0 x0 ,以及由观测值扩散得到的 x 1 : T x_{1:T} x1:T ,由边缘概率分布公式:
p θ ( x 0 ) = ∫ p θ ( x 0 : T ) d x 1 : T p_{\theta }(\boldsymbol{x}_{0})=\int p_{\theta}(x_{0:T}) d x_{1:T} pθ(x0)=pθ(x0:T)dx1:T
因此
log ⁡ p θ ( x 0 ) = log ⁡ ∫ p θ ( x 0 : T ) d x 1 : T = log ⁡ ∫ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) q ϕ ( x 1 : T ∣ x 0 ) d x 1 : T = log ⁡ E q ϕ ( x 1 : T ∣ x 0 ) [ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] ≥ E q ϕ ( x 1 : T ∣ x 0 ) [ log ⁡ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \begin{aligned} \log p_{\theta }(\boldsymbol{x}_{0}) & =\log \int p_{\theta}(\boldsymbol{x_{0:T}}) d \boldsymbol{x_{1:T}} \\ & =\log \int \frac{p_{\theta}(\boldsymbol{x_{0:T}}) q_{\phi}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})} d \boldsymbol{\boldsymbol{x_{1:T}}} \\ & =\log \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{x_{1:T}} \mid \boldsymbol{x_0})}\left[\frac{p_{\theta }(\boldsymbol{\boldsymbol{x_{0:T}}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \\ & \geq \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \end{aligned} logpθ(x0)=logpθ(x0:T)dx1:T=logqϕ(x1:Tx0)pθ(x0:T)qϕ(x1:Tx0)dx1:T=logEqϕ(x1:Tx0)[qϕ(x1:Tx0)pθ(x0:T)]Eqϕ(x1:Tx0)[logqϕ(x1:Tx0)pθ(x0:T)]
最后一步是由琴声不等式 ( J e n s e n ′ s I n e q u a l i t y ) (Jensen's Inequality) (JensensInequality) 得来的。
这样看来不是很直观,还有一种推导方式更简单一些:
log ⁡ p ( x ) = log ⁡ p ( x ) ∫ q ϕ ( z ∣ x ) d z = ∫ q ϕ ( z ∣ x ) ( log ⁡ p ( x ) ) d z = E q ϕ ( z ∣ x ) [ log ⁡ p ( x ) ] = E q ϕ ( z ∣ x ) [ log ⁡ p ( x , z ) p ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ log ⁡ p ( x , z ) q ϕ ( z ∣ x ) p ( z ∣ x ) q ϕ ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ log ⁡ p ( x , z ) q ϕ ( z ∣ x ) ] + E q ϕ ( z ∣ x ) [ log ⁡ q ϕ ( z ∣ x ) p ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ log ⁡ p ( x , z ) q ϕ ( z ∣ x ) ] + D K L ( q ϕ ( z ∣ x ) ∥ p ( z ∣ x ) ) ≥ E q ϕ ( z ∣ x ) [ log ⁡ p ( x , z ) q ϕ ( z ∣ x ) ] = E q ϕ ( x 1 : T ∣ x 0 ) [ log ⁡ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \begin{aligned} \log p(\boldsymbol{x}) & =\log p(\boldsymbol{x}) \int q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x}) d z \\ & =\int q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})(\log p(\boldsymbol{x})) d z \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}[\log p(\boldsymbol{x})] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{p(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z}) q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}{p(\boldsymbol{z} \mid \boldsymbol{x}) q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\right]+\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}{p(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\right]+D_{\mathrm{KL}}\left(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x}) \| p(\boldsymbol{z} \mid \boldsymbol{x})\right) \\ & \geq \mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & = \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \end{aligned} logp(x)=logp(x)qϕ(zx)dz=qϕ(zx)(logp(x))dz=Eqϕ(zx)[logp(x)]=Eqϕ(zx)[logp(zx)p(x,z)]=Eqϕ(zx)[logp(zx)qϕ(zx)p(x,z)qϕ(zx)]=Eqϕ(zx)[logqϕ(zx)p(x,z)]+Eqϕ(zx)[logp(zx)qϕ(zx)]=Eqϕ(zx)[logqϕ(zx)p(x,z)]+DKL(qϕ(zx)p(zx))Eqϕ(zx)[logqϕ(zx)p(x,z)]=Eqϕ(x1:Tx0)[logqϕ(x1:Tx0)pθ(x0:T)]
这里的 z z z 代表 x 1 : T x_{1:T} x1:T x x x 代表 x 0 x_0 x0。中间用到了贝叶斯公式。

这里两种方法推导的结论都为:
log ⁡ p θ ( x 0 ) ≥ E q ϕ ( x 1 : T ∣ x 0 ) [ log ⁡ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \log p_{\theta }(\boldsymbol{x}_{0}) \geq \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] logpθ(x0)Eqϕ(x1:Tx0)[logqϕ(x1:Tx0)pθ(x0:T)]

其中 E q ϕ ( x 1 : T ∣ x 0 ) [ log ⁡ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] Eqϕ(x1:Tx0)[logqϕ(x1:Tx0)pθ(x0:T)] 就是ELBO ( E v i d e n c e L o w e r B o u n d ) (Evidence Lower Bound) (EvidenceLowerBound),即变分下界。

我们要使损失函数 − log ⁡ p θ ( x 0 ) -\log p_{\theta}\left(x_{0}\right) logpθ(x0) 最小,就是另ELBO E q ϕ ( x 1 : T ∣ x 0 ) [ log ⁡ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] Eqϕ(x1:Tx0)[logqϕ(x1:Tx0)pθ(x0:T)] 最大。

我们令 L V L B = − E q ϕ ( x 1 : T ∣ x 0 ) [ log ⁡ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] = E q ϕ ( x 1 : T ∣ x 0 ) [ log ⁡ q ϕ ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] L_{VLB} = -\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] = \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}{p_{\theta }(\boldsymbol{x_{0:T}})}\right] LVLB=Eqϕ(x1:Tx0)[logqϕ(x1:Tx0)pθ(x0:T)]=Eqϕ(x1:Tx0)[logpθ(x0:T)qϕ(x1:Tx0)],接下来就转换为对该损失函数求解。再进一步分解:
L V L B = E q ( x 0 : T ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] = E q [ log ⁡ ∏ t = 1 T q ( x t ∣ x t − 1 ) p θ ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 1 T log ⁡ q ( x t ∣ x t − 1 ) p θ ( x t − 1 ∣ x t ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t ∣ x t − 1 ) p θ ( x t − 1 ∣ x t ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ ( q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) ⋅ q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) + ∑ t = 2 T log ⁡ q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) + log ⁡ q ( x T ∣ x 0 ) q ( x 1 ∣ x 0 ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ log ⁡ q ( x T ∣ x 0 ) p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) − log ⁡ p θ ( x 0 ∣ x 1 ) ] = E q [ D K L ( q ( x T ∣ x 0 ) ∥ p θ ( x T ) ) ⏟ L T + ∑ t = 2 T D K L ( q ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) ⏟ L t − 1 − log ⁡ p θ ( x 0 ∣ x 1 ) ⏟ L 0 ] \begin{aligned} L_{\mathrm{VLB}} & =\mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0: T}\right)}\right] \\ & =\mathbb{E}_{q}\left[\log \frac{\prod_{t=1}^{T} q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}\right] \\ & =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=1}^{T} \log \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}\right] \\ & =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \\ & =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \left(\frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)} \cdot \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right)}\right)+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \\ & =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right)}+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \\ & =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}+\log \frac{q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \\ & =\mathbb{E}_{q}\left[\log \frac{q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{T}\right)}+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}-\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)\right] \\ & =\mathbb{E}_{q}[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{T}\right)\right)}_{L_{T}}+\sum_{t=2}^{T} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right)}_{L_{t-1}}-\underbrace{\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}_{L_{0}}] \end{aligned} LVLB=Eq(x0:T)[logpθ(x0:T)q(x1:Tx0)]=Eq[logpθ(xT)t=1Tpθ(xt1xt)t=1Tq(xtxt1)]=Eq[logpθ(xT)+t=1Tlogpθ(xt1xt)q(xtxt1)]=Eq[logpθ(xT)+t=2Tlogpθ(xt1xt)q(xtxt1)+logpθ(x0x1)q(x1x0)]=Eq[logpθ(xT)+t=2Tlog(pθ(xt1xt)q(xt1xt,x0)q(xt1x0)q(xtx0))+logpθ(x0x1)q(x1x0)]=Eq[logpθ(xT)+t=2Tlogpθ(xt1xt)q(xt1xt,x0)+t=2Tlogq(xt1x0)q(xtx0)+logpθ(x0x1)q(x1x0)]=Eq[logpθ(xT)+t=2Tlogpθ(xt1xt)q(xt1xt,x0)+logq(x1x0)q(xTx0)+logpθ(x0x1)q(x1x0)]=Eq[logpθ(xT)q(xTx0)+t=2Tlogpθ(xt1xt)q(xt1xt,x0)logpθ(x0x1)]=Eq[LT DKL(q(xTx0)pθ(xT))+t=2TLt1 DKL(q(xt1xt,x0)pθ(xt1xt))L0 logpθ(x0x1)]

这里刚开始的下脚标从 q ϕ ( x 1 : T ∣ x 0 ) q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0}) qϕ(x1:Tx0) 换为了 q ( x 0 : T ) q\left(\mathbf{x}_{0: T}\right) q(x0:T),我认为 x 0 x_0 x0 为已知,所以这两个式子是等价的。

中间还有容易混淆的一点是, q ( x 1 : T ∣ x 0 ) q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right) q(x1:Tx0)代表前向传播的分布, q ( x t − 1 ∣ x t ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) q(xt1xt) 代表逆扩散过程的真实分布,这里的 q q q 没有正向或反向的含义,只是代表概率和分布,可以理解为概率论中的概率 P P P p θ ( x t − 1 ∣ x t ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) pθ(xt1xt) 代表我们要求解的逆扩散分布。

接下来我们对 L T , L t − 1 , L 0 L_{T}, L_{t-1}, L_{0} LT,Lt1,L0 这三种情况进行分类讨论:

2.3.2 L T L_{T} LT

q ( x 1 : T ∣ x 0 ) q\left(\mathbf{x}_{1:T} \mid \mathbf{x}_{0}\right) q(x1:Tx0) 代表前向扩散过程,没有可学习参数; p θ ( x T ) p_{\theta}\left(\mathbf{x}_{T}\right) pθ(xT) 中的 x T x_T xT 为服从标准高斯分布的噪声, p θ p_{\theta} pθ 为逆扩散过程,对逆扩散过程而言, x T x_T xT 为已知的,所以这一项 L T L_{T} LT 可以当作常量。

2.3.3 L t − 1 L_{t-1} Lt1

L t − 1 L_{t-1} Lt1 可以看出是真实的逆扩散分布 q ( x t − 1 ∣ x t , x 0 ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) q(xt1xt,x0) 和我们要求的逆扩散分布 p θ ( x t − 1 ∣ x t ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) pθ(xt1xt) 的KL散度。

  1. 真实分布 q ( x t − 1 ∣ x t , x 0 ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) q(xt1xt,x0) 的均值和方差我们已经求出:
    μ ~ t = 1 α t ( x t − 1 − α t 1 − α ˉ t ε t ) , β ~ t = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t \tilde{\mu}_{t}=\frac{1}{\sqrt{\alpha_{t}}}\left(x_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \varepsilon_{t}\right), \tilde{\beta}_{t}=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}} \cdot \beta_{t} μ~t=αt 1(xt1αˉt 1αtεt),β~t=1αˉt1αˉt1βt
  2. 第二个分布 p θ ( x t − 1 ∣ x t ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) pθ(xt1xt) 是我们希望拟合的目标分布,也是一个高斯分布,均值用网络估计,方差被设置为了和 β t \beta_t βt 有关:
    p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right), \boldsymbol{\Sigma}_{\theta}\left(\mathbf{x}_{t}, t\right)\right) pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))

所以为了使这两个分布接近,方差我们可以不用管,只需让两个分布的均值的距离最小,我们用二范数来表示:
L t = E q [ ∥ μ ~ t ( x t , x 0 ) − μ θ ( x t , t ) ∥ 2 ] = E x 0 , ϵ [ ∥ 1 α t ( x t ( x 0 , ϵ ) − β t 1 − α ˉ t ϵ ) − μ θ ( x t ( x 0 , ϵ ) , t ) ∥ 2 ] ϵ ∼ N ( 0 , 1 ) \begin{aligned} L_{t} & =\mathbb{E}_{q}\left[\left\|\tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)-\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right)\right\|^{2}\right] \\ & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\left\|\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \epsilon\right)-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon\right)-\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \epsilon\right), t\right)\right\|^{2}\right] \quad \epsilon \sim \mathcal{N}(0,1) \end{aligned} Lt=Eq[μ~t(xt,x0)μθ(xt,t)2]=Ex0,ϵ[ αt 1(xt(x0,ϵ)1αˉt βtϵ)μθ(xt(x0,ϵ),t) 2]ϵN(0,1)
在该公式中可以观察到,我们需要用 μ θ ( x t ( x 0 , ϵ ) , t ) \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \epsilon\right), t\right) μθ(xt(x0,ϵ),t) 去拟合 1 α t ( x t ( x 0 , ϵ ) − β t 1 − α ˉ t ϵ ) \frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \epsilon\right)-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon\right) αt 1(xt(x0,ϵ)1αˉt βtϵ) ,所以干脆令:
μ θ ( x t , t ) = 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right)=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)\right) μθ(xt,t)=αt 1(xt1αˉt βtϵθ(xt,t))
也就是直接用神经网络 ϵ θ ( x t , t ) \epsilon_{\theta}\left(\mathbf{x}_{t}, t\right) ϵθ(xt,t) 去预测噪声 ϵ \epsilon ϵ。然后把预测出来的噪声带入到定义好的表达式去计算出预测的均值。

所以损失函数变为了:
L t = E x 0 , ϵ [ ∥ 1 α t ( x t − β t 1 − α ˉ t ϵ ) − 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) ∥ 2 ] ϵ ∼ N ( 0 , 1 ) = E x 0 , ϵ [ ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 ] ϵ ∼ N ( 0 , 1 )  把常数项系数都扔了, 作者说这样更好训练  = E x 0 , ϵ [ ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∥ 2 ] , ϵ ∼ N ( 0 , 1 ) \begin{aligned} L_{t} & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\left\|\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon\right)-\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)\right)\right\|^{2}\right] \quad \epsilon \sim \mathcal{N}(0,1) \\ & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\left\|\epsilon-\epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)\right\|^{2}\right] \quad \epsilon \sim \mathcal{N}(0,1) \quad \text { 把常数项系数都扔了, 作者说这样更好训练 } \\ & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\left\|\epsilon-\epsilon_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \epsilon, t\right)\right\|^{2}\right], \quad \epsilon \sim \mathcal{N}(0,1) \end{aligned} Lt=Ex0,ϵ[ αt 1(xt1αˉt βtϵ)αt 1(xt1αˉt βtϵθ(xt,t)) 2]ϵN(0,1)=Ex0,ϵ[ϵϵθ(xt,t)2]ϵN(0,1) 把常数项系数都扔了作者说这样更好训练 =Ex0,ϵ[ ϵϵθ(αˉt x0+1αˉt ϵ,t) 2],ϵN(0,1)
网络的输入是一张和噪声线性组合的图片 x t x_t xt,与其组合的噪声真实值为 ϵ \epsilon ϵ,我们需要用 ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) \epsilon_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \epsilon, t\right) ϵθ(αˉt x0+1αˉt ϵ,t) 去拟合这个噪声。

2.3.4 L 0 L_{0} L0

最后的 L 0 = − log ⁡ p θ ( x 0 ∣ x 1 ) L_{0}=-\log p_{\theta}\left(x_{0} \mid x_{1}\right) L0=logpθ(x0x1) 是将最后一步的加噪图像 x 1 x_1 x1 生成去燥图像 x 0 x_0 x0 的极大似然估计,为了生成更好的图像,我们需要对每个像素都运用极大似然估计,使得图像上每个像素值都满足离散的对数似然。

为了达到这个目的,将逆扩散过程中的最后从 x 1 x_1 x1 x 0 x_0 x0 的转换设置为独立的离散计算方式。 即在最后一个转换过程在给定 x 1 x_1 x1 下得到图像 x 0 x_0 x0 满足对数似然,假设像素与像素之间是相互独立的:
p θ ( x 0 ∣ x 1 ) = ∏ i = 1 D p θ ( x 0 i ∣ x 1 i ) p_{\theta}\left(x_{0} \mid x_{1}\right)=\prod_{i=1}^{D} p_{\theta}\left(x_{0}^{i} \mid x_{1}^{i}\right) pθ(x0x1)=i=1Dpθ(x0ix1i)
D D D 代表 x x x 的维度,上标 i i i 表示图像中的一个坐标位置。现在的目标是确定给定像素的值可能性有多大,也就是想要知道对应时间步 t = 1 t=1 t=1 下噪声图像 x x x 中相应像素值的分布:
N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) \mathcal{N}\left(x ; \mu_{\theta}^{i}\left(x_{1}, 1\right), \sigma_{1}^{2}\right) N(x;μθi(x1,1),σ12)
其中 t = 1 t=1 t=1 的像素分布来自多元高斯分布,其对角协方差矩阵允许我们将分布拆分为单变量高斯的乘积:
N ( x ; μ θ ( x 1 , 1 ) , σ 1 2 I ) = ∏ i = 1 D N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) \mathcal{N}\left(x ; \mu_{\theta}\left(x_{1}, 1\right), \sigma_{1}^{2} \mathbb{I}\right)=\prod_{i=1}^{D} \mathcal{N}\left(x ; \mu_{\theta}^{i}\left(x_{1}, 1\right), \sigma_{1}^{2}\right) N(x;μθ(x1,1),σ12I)=i=1DN(x;μθi(x1,1),σ12)
现在假设图像已经从0-255的数值之间,经过归一化在[-1,1]的范围内。在 t=0 时给定每个像素的像素值,最后一个时间步 t=1 的转换概率分布 p θ ( x 0 ∣ x 1 ) p_{\theta}\left(x_{0} \mid x_{1}\right) pθ(x0x1) 的值就是每个像素值的乘积。所以:
p θ ( x 0 ∣ x 1 ) = ∏ i = 1 D ∫ δ − ( x 0 i ) δ + ( x 0 i ) N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) d x δ + ( x ) = { ∞  if  x = 1 x + 1 255  if  x < 1 δ − ( x ) = { − ∞  if  x = − 1 x − 1 255  if  x > − 1 \begin{aligned} p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right) & =\prod_{i=1}^{D} \int_{\delta_{-}\left(x_{0}^{i}\right)}^{\delta_{+}\left(x_{0}^{i}\right)} \mathcal{N}\left(x ; \mu_{\theta}^{i}\left(\mathbf{x}_{1}, 1\right), \sigma_{1}^{2}\right) d x \\ \delta_{+}(x) & =\left\{\begin{array}{ll} \infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1 \end{array} \quad \delta_{-}(x)=\left\{\begin{array}{ll} -\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1 \end{array}\right.\right. \end{aligned} pθ(x0x1)δ+(x)=i=1Dδ(x0i)δ+(x0i)N(x;μθi(x1,1),σ12)dx={ x+2551 if x=1 if x<1δ(x)={ x2551 if x=1 if x>1
这个公式来自原论文,这里解析一下它的含义。就是我们要将最后一步的加噪图像 x 1 x_1 x1 拟合去燥图像 x 0 x_0 x0 ,把图像的每一个像素点都设为一个高斯分布,一共 D D D 个像素点。而 x 0 x_0 x0 每个像素点原本的取值范围为 { 0 , 1 , … , 255 } \{0,1, \ldots, 255\} { 0,1,,255} ,经过归一化映射到了 [ − 1 , 1 ] [-1,1] [1,1] 范围内。

现在我们把单独取出一个 x 1 x_1 x1 上的像素点 x 1 i x_1^i x1i,它服从分布 N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) \mathcal{N}\left(x ; \mu_{\theta}^{i}\left(x_{1}, 1\right), \sigma_{1}^{2}\right) N(x;μθi(x1,1),σ12),需要拟合的目标为 x 0 x_0 x0 上的对应位置像素点 x 0 i x_0^i x0i ,而 x 0 i x_0^i x0i 的取值范围由原本的离散空间 { 0 , 1 , … , 255 } \{0,1, \ldots, 255\} { 0,1,,255} 映射到了连续空间 [ − 1 , 1 ] [-1,1] [1,1],所以每个原本的离散值在连续空间中对应一个区间,而区间映射的公式就是:
δ + ( x ) = { ∞  if  x = 1 x + 1 255  if  x < 1 δ − ( x ) = { − ∞  if  x = − 1 x − 1 255  if  x > − 1 \begin{aligned} \delta_{+}(x) & =\left\{\begin{array}{ll} \infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1 \end{array} \quad \delta_{-}(x)=\left\{\begin{array}{ll} -\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1 \end{array}\right.\right. \end{aligned} δ+(x)={ x+2551 if x=1 if x<1δ(x)={ x2551 if x=1 if x>1

以上就是对DDPM得扩散和逆扩散过程中涉及到的所有公式的解析和推导,包括损失函数构建部分。

参考文献和博客

Understanding Diffusion Models: A Unified Perspective
Denoising Diffusion Probabilistic Models
https://yinglinzheng.netlify.app/diffusion-model-tutorial
https://zhuanlan.zhihu.com/p/549623622

猜你喜欢

转载自blog.csdn.net/weixin_45453121/article/details/131223653
今日推荐