DDPM扩散模型公式推理----损失函数
目录
2.3 损失函数推导
我们用极大似然估计的思想构建损失函数:
L = − log p θ ( x 0 ) \mathcal{L}=-\log p_{\theta}\left(x_{0}\right) L=−logpθ(x0)
即逆扩散网络的参数 θ \theta θ 使刚开始采样的数据 x 0 x_0 x0 出现的概率最大。
接下来需要对上式进行变形,用到了一些ELBO和VAE的内容。
2.3.1 ELBO
已知极大似然估计 p θ ( x 0 ) p_{\theta}\left(x_{0}\right) pθ(x0) 和观测值 x 0 x_0 x0 ,以及由观测值扩散得到的 x 1 : T x_{1:T} x1:T ,由边缘概率分布公式:
p θ ( x 0 ) = ∫ p θ ( x 0 : T ) d x 1 : T p_{\theta }(\boldsymbol{x}_{0})=\int p_{\theta}(x_{0:T}) d x_{1:T} pθ(x0)=∫pθ(x0:T)dx1:T
因此
log p θ ( x 0 ) = log ∫ p θ ( x 0 : T ) d x 1 : T = log ∫ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) q ϕ ( x 1 : T ∣ x 0 ) d x 1 : T = log E q ϕ ( x 1 : T ∣ x 0 ) [ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] ≥ E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \begin{aligned} \log p_{\theta }(\boldsymbol{x}_{0}) & =\log \int p_{\theta}(\boldsymbol{x_{0:T}}) d \boldsymbol{x_{1:T}} \\ & =\log \int \frac{p_{\theta}(\boldsymbol{x_{0:T}}) q_{\phi}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})} d \boldsymbol{\boldsymbol{x_{1:T}}} \\ & =\log \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{x_{1:T}} \mid \boldsymbol{x_0})}\left[\frac{p_{\theta }(\boldsymbol{\boldsymbol{x_{0:T}}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \\ & \geq \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \end{aligned} logpθ(x0)=log∫pθ(x0:T)dx1:T=log∫qϕ(x1:T∣x0)pθ(x0:T)qϕ(x1:T∣x0)dx1:T=logEqϕ(x1:T∣x0)[qϕ(x1:T∣x0)pθ(x0:T)]≥Eqϕ(x1:T∣x0)[logqϕ(x1:T∣x0)pθ(x0:T)]
最后一步是由琴声不等式 ( J e n s e n ′ s I n e q u a l i t y ) (Jensen's Inequality) (Jensen′sInequality) 得来的。
这样看来不是很直观,还有一种推导方式更简单一些:
log p ( x ) = log p ( x ) ∫ q ϕ ( z ∣ x ) d z = ∫ q ϕ ( z ∣ x ) ( log p ( x ) ) d z = E q ϕ ( z ∣ x ) [ log p ( x ) ] = E q ϕ ( z ∣ x ) [ log p ( x , z ) p ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ log p ( x , z ) q ϕ ( z ∣ x ) p ( z ∣ x ) q ϕ ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ log p ( x , z ) q ϕ ( z ∣ x ) ] + E q ϕ ( z ∣ x ) [ log q ϕ ( z ∣ x ) p ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ log p ( x , z ) q ϕ ( z ∣ x ) ] + D K L ( q ϕ ( z ∣ x ) ∥ p ( z ∣ x ) ) ≥ E q ϕ ( z ∣ x ) [ log p ( x , z ) q ϕ ( z ∣ x ) ] = E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \begin{aligned} \log p(\boldsymbol{x}) & =\log p(\boldsymbol{x}) \int q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x}) d z \\ & =\int q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})(\log p(\boldsymbol{x})) d z \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}[\log p(\boldsymbol{x})] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{p(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z}) q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}{p(\boldsymbol{z} \mid \boldsymbol{x}) q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\right]+\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}{p(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\right]+D_{\mathrm{KL}}\left(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x}) \| p(\boldsymbol{z} \mid \boldsymbol{x})\right) \\ & \geq \mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & = \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \end{aligned} logp(x)=logp(x)∫qϕ(z∣x)dz=∫qϕ(z∣x)(logp(x))dz=Eqϕ(z∣x)[logp(x)]=Eqϕ(z∣x)[logp(z∣x)p(x,z)]=Eqϕ(z∣x)[logp(z∣x)qϕ(z∣x)p(x,z)qϕ(z∣x)]=Eqϕ(z∣x)[logqϕ(z∣x)p(x,z)]+Eqϕ(z∣x)[logp(z∣x)qϕ(z∣x)]=Eqϕ(z∣x)[logqϕ(z∣x)p(x,z)]+DKL(qϕ(z∣x)∥p(z∣x))≥Eqϕ(z∣x)[logqϕ(z∣x)p(x,z)]=Eqϕ(x1:T∣x0)[logqϕ(x1:T∣x0)pθ(x0:T)]
这里的 z z z 代表 x 1 : T x_{1:T} x1:T, x x x 代表 x 0 x_0 x0。中间用到了贝叶斯公式。
这里两种方法推导的结论都为:
log p θ ( x 0 ) ≥ E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \log p_{\theta }(\boldsymbol{x}_{0}) \geq \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] logpθ(x0)≥Eqϕ(x1:T∣x0)[logqϕ(x1:T∣x0)pθ(x0:T)]
其中 E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] Eqϕ(x1:T∣x0)[logqϕ(x1:T∣x0)pθ(x0:T)] 就是ELBO ( E v i d e n c e L o w e r B o u n d ) (Evidence Lower Bound) (EvidenceLowerBound),即变分下界。
我们要使损失函数 − log p θ ( x 0 ) -\log p_{\theta}\left(x_{0}\right) −logpθ(x0) 最小,就是另ELBO E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] Eqϕ(x1:T∣x0)[logqϕ(x1:T∣x0)pθ(x0:T)] 最大。
我们令 L V L B = − E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] = E q ϕ ( x 1 : T ∣ x 0 ) [ log q ϕ ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] L_{VLB} = -\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] = \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}{p_{\theta }(\boldsymbol{x_{0:T}})}\right] LVLB=−Eqϕ(x1:T∣x0)[logqϕ(x1:T∣x0)pθ(x0:T)]=Eqϕ(x1:T∣x0)[logpθ(x0:T)qϕ(x1:T∣x0)],接下来就转换为对该损失函数求解。再进一步分解:
L V L B = E q ( x 0 : T ) [ log q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] = E q [ log ∏ t = 1 T q ( x t ∣ x t − 1 ) p θ ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) ] = E q [ − log p θ ( x T ) + ∑ t = 1 T log q ( x t ∣ x t − 1 ) p θ ( x t − 1 ∣ x t ) ] = E q [ − log p θ ( x T ) + ∑ t = 2 T log q ( x t ∣ x t − 1 ) p θ ( x t − 1 ∣ x t ) + log q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log p θ ( x T ) + ∑ t = 2 T log ( q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) ⋅ q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) ) + log q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log p θ ( x T ) + ∑ t = 2 T log q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) + ∑ t = 2 T log q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) + log q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log p θ ( x T ) + ∑ t = 2 T log q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) + log q ( x T ∣ x 0 ) q ( x 1 ∣ x 0 ) + log q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ log q ( x T ∣ x 0 ) p θ ( x T ) + ∑ t = 2 T log q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) − log p θ ( x 0 ∣ x 1 ) ] = E q [ D K L ( q ( x T ∣ x 0 ) ∥ p θ ( x T ) ) ⏟ L T + ∑ t = 2 T D K L ( q ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) ⏟ L t − 1 − log p θ ( x 0 ∣ x 1 ) ⏟ L 0 ] \begin{aligned} L_{\mathrm{VLB}} & =\mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0: T}\right)}\right] \\ & =\mathbb{E}_{q}\left[\log \frac{\prod_{t=1}^{T} q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}\right] \\ & =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=1}^{T} \log \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}\right] \\ & =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \\ & =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \left(\frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)} \cdot \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right)}\right)+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \\ & =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right)}+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \\ & =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}+\log \frac{q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}+\log \frac{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}\right] \\ & =\mathbb{E}_{q}\left[\log \frac{q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{T}\right)}+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}-\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)\right] \\ & =\mathbb{E}_{q}[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{T}\right)\right)}_{L_{T}}+\sum_{t=2}^{T} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right)}_{L_{t-1}}-\underbrace{\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}_{L_{0}}] \end{aligned} LVLB=Eq(x0:T)[logpθ(x0:T)q(x1:T∣x0)]=Eq[logpθ(xT)∏t=1Tpθ(xt−1∣xt)∏t=1Tq(xt∣xt−1)]=Eq[−logpθ(xT)+t=1∑Tlogpθ(xt−1∣xt)q(xt∣xt−1)]=Eq[−logpθ(xT)+t=2∑Tlogpθ(xt−1∣xt)q(xt∣xt−1)+logpθ(x0∣x1)q(x1∣x0)]=Eq[−logpθ(xT)+t=2∑Tlog(pθ(xt−1∣xt)q(xt−1∣xt,x0)⋅q(xt−1∣x0)q(xt∣x0))+logpθ(x0∣x1)q(x1∣x0)]=Eq[−logpθ(xT)+t=2∑Tlogpθ(xt−1∣xt)q(xt−1∣xt,x0)+t=2∑Tlogq(xt−1∣x0)q(xt∣x0)+logpθ(x0∣x1)q(x1∣x0)]=Eq[−logpθ(xT)+t=2∑Tlogpθ(xt−1∣xt)q(xt−1∣xt,x0)+logq(x1∣x0)q(xT∣x0)+logpθ(x0∣x1)q(x1∣x0)]=Eq[logpθ(xT)q(xT∣x0)+t=2∑Tlogpθ(xt−1∣xt)q(xt−1∣xt,x0)−logpθ(x0∣x1)]=Eq[LT
DKL(q(xT∣x0)∥pθ(xT))+t=2∑TLt−1
DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))−L0
logpθ(x0∣x1)]
这里刚开始的下脚标从 q ϕ ( x 1 : T ∣ x 0 ) q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0}) qϕ(x1:T∣x0) 换为了 q ( x 0 : T ) q\left(\mathbf{x}_{0: T}\right) q(x0:T),我认为 x 0 x_0 x0 为已知,所以这两个式子是等价的。
中间还有容易混淆的一点是, q ( x 1 : T ∣ x 0 ) q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right) q(x1:T∣x0)代表前向传播的分布, q ( x t − 1 ∣ x t ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) q(xt−1∣xt) 代表逆扩散过程的真实分布,这里的 q q q 没有正向或反向的含义,只是代表概率和分布,可以理解为概率论中的概率 P P P。 p θ ( x t − 1 ∣ x t ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) pθ(xt−1∣xt) 代表我们要求解的逆扩散分布。
接下来我们对 L T , L t − 1 , L 0 L_{T}, L_{t-1}, L_{0} LT,Lt−1,L0 这三种情况进行分类讨论:
2.3.2 L T L_{T} LT
q ( x 1 : T ∣ x 0 ) q\left(\mathbf{x}_{1:T} \mid \mathbf{x}_{0}\right) q(x1:T∣x0) 代表前向扩散过程,没有可学习参数; p θ ( x T ) p_{\theta}\left(\mathbf{x}_{T}\right) pθ(xT) 中的 x T x_T xT 为服从标准高斯分布的噪声, p θ p_{\theta} pθ 为逆扩散过程,对逆扩散过程而言, x T x_T xT 为已知的,所以这一项 L T L_{T} LT 可以当作常量。
2.3.3 L t − 1 L_{t-1} Lt−1
而 L t − 1 L_{t-1} Lt−1 可以看出是真实的逆扩散分布 q ( x t − 1 ∣ x t , x 0 ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) q(xt−1∣xt,x0) 和我们要求的逆扩散分布 p θ ( x t − 1 ∣ x t ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) pθ(xt−1∣xt) 的KL散度。
- 真实分布 q ( x t − 1 ∣ x t , x 0 ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) q(xt−1∣xt,x0) 的均值和方差我们已经求出:
μ ~ t = 1 α t ( x t − 1 − α t 1 − α ˉ t ε t ) , β ~ t = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t \tilde{\mu}_{t}=\frac{1}{\sqrt{\alpha_{t}}}\left(x_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \varepsilon_{t}\right), \tilde{\beta}_{t}=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}} \cdot \beta_{t} μ~t=αt1(xt−1−αˉt1−αtεt),β~t=1−αˉt1−αˉt−1⋅βt - 第二个分布 p θ ( x t − 1 ∣ x t ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) pθ(xt−1∣xt) 是我们希望拟合的目标分布,也是一个高斯分布,均值用网络估计,方差被设置为了和 β t \beta_t βt 有关:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right), \boldsymbol{\Sigma}_{\theta}\left(\mathbf{x}_{t}, t\right)\right) pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))
所以为了使这两个分布接近,方差我们可以不用管,只需让两个分布的均值的距离最小,我们用二范数来表示:
L t = E q [ ∥ μ ~ t ( x t , x 0 ) − μ θ ( x t , t ) ∥ 2 ] = E x 0 , ϵ [ ∥ 1 α t ( x t ( x 0 , ϵ ) − β t 1 − α ˉ t ϵ ) − μ θ ( x t ( x 0 , ϵ ) , t ) ∥ 2 ] ϵ ∼ N ( 0 , 1 ) \begin{aligned} L_{t} & =\mathbb{E}_{q}\left[\left\|\tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)-\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right)\right\|^{2}\right] \\ & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\left\|\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \epsilon\right)-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon\right)-\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \epsilon\right), t\right)\right\|^{2}\right] \quad \epsilon \sim \mathcal{N}(0,1) \end{aligned} Lt=Eq[∥μ~t(xt,x0)−μθ(xt,t)∥2]=Ex0,ϵ[
αt1(xt(x0,ϵ)−1−αˉtβtϵ)−μθ(xt(x0,ϵ),t)
2]ϵ∼N(0,1)
在该公式中可以观察到,我们需要用 μ θ ( x t ( x 0 , ϵ ) , t ) \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \epsilon\right), t\right) μθ(xt(x0,ϵ),t) 去拟合 1 α t ( x t ( x 0 , ϵ ) − β t 1 − α ˉ t ϵ ) \frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \epsilon\right)-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon\right) αt1(xt(x0,ϵ)−1−αˉtβtϵ) ,所以干脆令:
μ θ ( x t , t ) = 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right)=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)\right) μθ(xt,t)=αt1(xt−1−αˉtβtϵθ(xt,t))
也就是直接用神经网络 ϵ θ ( x t , t ) \epsilon_{\theta}\left(\mathbf{x}_{t}, t\right) ϵθ(xt,t) 去预测噪声 ϵ \epsilon ϵ。然后把预测出来的噪声带入到定义好的表达式去计算出预测的均值。
所以损失函数变为了:
L t = E x 0 , ϵ [ ∥ 1 α t ( x t − β t 1 − α ˉ t ϵ ) − 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) ∥ 2 ] ϵ ∼ N ( 0 , 1 ) = E x 0 , ϵ [ ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 ] ϵ ∼ N ( 0 , 1 ) 把常数项系数都扔了, 作者说这样更好训练 = E x 0 , ϵ [ ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∥ 2 ] , ϵ ∼ N ( 0 , 1 ) \begin{aligned} L_{t} & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\left\|\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon\right)-\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)\right)\right\|^{2}\right] \quad \epsilon \sim \mathcal{N}(0,1) \\ & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\left\|\epsilon-\epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)\right\|^{2}\right] \quad \epsilon \sim \mathcal{N}(0,1) \quad \text { 把常数项系数都扔了, 作者说这样更好训练 } \\ & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\left\|\epsilon-\epsilon_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \epsilon, t\right)\right\|^{2}\right], \quad \epsilon \sim \mathcal{N}(0,1) \end{aligned} Lt=Ex0,ϵ[
αt1(xt−1−αˉtβtϵ)−αt1(xt−1−αˉtβtϵθ(xt,t))
2]ϵ∼N(0,1)=Ex0,ϵ[∥ϵ−ϵθ(xt,t)∥2]ϵ∼N(0,1) 把常数项系数都扔了, 作者说这样更好训练 =Ex0,ϵ[
ϵ−ϵθ(αˉtx0+1−αˉtϵ,t)
2],ϵ∼N(0,1)
网络的输入是一张和噪声线性组合的图片 x t x_t xt,与其组合的噪声真实值为 ϵ \epsilon ϵ,我们需要用 ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) \epsilon_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \epsilon, t\right) ϵθ(αˉtx0+1−αˉtϵ,t) 去拟合这个噪声。
2.3.4 L 0 L_{0} L0
最后的 L 0 = − log p θ ( x 0 ∣ x 1 ) L_{0}=-\log p_{\theta}\left(x_{0} \mid x_{1}\right) L0=−logpθ(x0∣x1) 是将最后一步的加噪图像 x 1 x_1 x1 生成去燥图像 x 0 x_0 x0 的极大似然估计,为了生成更好的图像,我们需要对每个像素都运用极大似然估计,使得图像上每个像素值都满足离散的对数似然。
为了达到这个目的,将逆扩散过程中的最后从 x 1 x_1 x1 到 x 0 x_0 x0 的转换设置为独立的离散计算方式。 即在最后一个转换过程在给定 x 1 x_1 x1 下得到图像 x 0 x_0 x0 满足对数似然,假设像素与像素之间是相互独立的:
p θ ( x 0 ∣ x 1 ) = ∏ i = 1 D p θ ( x 0 i ∣ x 1 i ) p_{\theta}\left(x_{0} \mid x_{1}\right)=\prod_{i=1}^{D} p_{\theta}\left(x_{0}^{i} \mid x_{1}^{i}\right) pθ(x0∣x1)=i=1∏Dpθ(x0i∣x1i)
D D D 代表 x x x 的维度,上标 i i i 表示图像中的一个坐标位置。现在的目标是确定给定像素的值可能性有多大,也就是想要知道对应时间步 t = 1 t=1 t=1 下噪声图像 x x x 中相应像素值的分布:
N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) \mathcal{N}\left(x ; \mu_{\theta}^{i}\left(x_{1}, 1\right), \sigma_{1}^{2}\right) N(x;μθi(x1,1),σ12)
其中 t = 1 t=1 t=1 的像素分布来自多元高斯分布,其对角协方差矩阵允许我们将分布拆分为单变量高斯的乘积:
N ( x ; μ θ ( x 1 , 1 ) , σ 1 2 I ) = ∏ i = 1 D N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) \mathcal{N}\left(x ; \mu_{\theta}\left(x_{1}, 1\right), \sigma_{1}^{2} \mathbb{I}\right)=\prod_{i=1}^{D} \mathcal{N}\left(x ; \mu_{\theta}^{i}\left(x_{1}, 1\right), \sigma_{1}^{2}\right) N(x;μθ(x1,1),σ12I)=i=1∏DN(x;μθi(x1,1),σ12)
现在假设图像已经从0-255的数值之间,经过归一化在[-1,1]的范围内。在 t=0 时给定每个像素的像素值,最后一个时间步 t=1 的转换概率分布 p θ ( x 0 ∣ x 1 ) p_{\theta}\left(x_{0} \mid x_{1}\right) pθ(x0∣x1) 的值就是每个像素值的乘积。所以:
p θ ( x 0 ∣ x 1 ) = ∏ i = 1 D ∫ δ − ( x 0 i ) δ + ( x 0 i ) N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) d x δ + ( x ) = { ∞ if x = 1 x + 1 255 if x < 1 δ − ( x ) = { − ∞ if x = − 1 x − 1 255 if x > − 1 \begin{aligned} p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right) & =\prod_{i=1}^{D} \int_{\delta_{-}\left(x_{0}^{i}\right)}^{\delta_{+}\left(x_{0}^{i}\right)} \mathcal{N}\left(x ; \mu_{\theta}^{i}\left(\mathbf{x}_{1}, 1\right), \sigma_{1}^{2}\right) d x \\ \delta_{+}(x) & =\left\{\begin{array}{ll} \infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1 \end{array} \quad \delta_{-}(x)=\left\{\begin{array}{ll} -\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1 \end{array}\right.\right. \end{aligned} pθ(x0∣x1)δ+(x)=i=1∏D∫δ−(x0i)δ+(x0i)N(x;μθi(x1,1),σ12)dx={
∞x+2551 if x=1 if x<1δ−(x)={
−∞x−2551 if x=−1 if x>−1
这个公式来自原论文,这里解析一下它的含义。就是我们要将最后一步的加噪图像 x 1 x_1 x1 拟合去燥图像 x 0 x_0 x0 ,把图像的每一个像素点都设为一个高斯分布,一共 D D D 个像素点。而 x 0 x_0 x0 每个像素点原本的取值范围为 { 0 , 1 , … , 255 } \{0,1, \ldots, 255\} {
0,1,…,255} ,经过归一化映射到了 [ − 1 , 1 ] [-1,1] [−1,1] 范围内。
现在我们把单独取出一个 x 1 x_1 x1 上的像素点 x 1 i x_1^i x1i,它服从分布 N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) \mathcal{N}\left(x ; \mu_{\theta}^{i}\left(x_{1}, 1\right), \sigma_{1}^{2}\right) N(x;μθi(x1,1),σ12),需要拟合的目标为 x 0 x_0 x0 上的对应位置像素点 x 0 i x_0^i x0i ,而 x 0 i x_0^i x0i 的取值范围由原本的离散空间 { 0 , 1 , … , 255 } \{0,1, \ldots, 255\} {
0,1,…,255} 映射到了连续空间 [ − 1 , 1 ] [-1,1] [−1,1],所以每个原本的离散值在连续空间中对应一个区间,而区间映射的公式就是:
δ + ( x ) = { ∞ if x = 1 x + 1 255 if x < 1 δ − ( x ) = { − ∞ if x = − 1 x − 1 255 if x > − 1 \begin{aligned} \delta_{+}(x) & =\left\{\begin{array}{ll} \infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1 \end{array} \quad \delta_{-}(x)=\left\{\begin{array}{ll} -\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1 \end{array}\right.\right. \end{aligned} δ+(x)={
∞x+2551 if x=1 if x<1δ−(x)={
−∞x−2551 if x=−1 if x>−1
以上就是对DDPM得扩散和逆扩散过程中涉及到的所有公式的解析和推导,包括损失函数构建部分。
参考文献和博客
Understanding Diffusion Models: A Unified Perspective
Denoising Diffusion Probabilistic Models
https://yinglinzheng.netlify.app/diffusion-model-tutorial
https://zhuanlan.zhihu.com/p/549623622