DDPM 拡散モデルの公式推論 ----損失関数
目次
2.3 損失関数の導出
最尤推定の考え方を使用して損失関数を構築します。
L = − log p θ ( x 0 ) \mathcal{L}=-\log p_{\theta}\left(x_{0}\right)L=−ログ_p私( ×0)
は逆拡散ネットワークのθ \thetaθ はデータをx 0 にします x_0 はサンプリングを開始したばかりですバツ0最も起こりやすい。
次に、ELBO および VAE コンテンツを使用して、上記の式を変換する必要があります。
2.3.1 ELBA
既知の最尤推定値p θ ( x 0 ) p_{\theta}\left(x_{0}\right)p私( ×0)と観察x 0 x_0バツ0、および観測値の拡散によって得られる x 1 : T x_{1:T}バツ1 : T、周辺確率分布の公式による:
p θ ( x 0 ) = ∫ p θ ( x 0 : T ) dx 1 : T p_{\theta }(\boldsymbol{x}_{0})=\int p_{\シータ }(x_{0:T}) d x_{1:T}p私( ×0)=∫p私( ×0 : T) dx _1 : T
したがって
log p θ ( x 0 ) = log ∫ p θ ( x 0 : T ) dx 1 : T = log ∫ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) q ϕ ( x 1 : T ∣ x 0 ) dx 1 : T = log E q ϕ ( x 1 : T ∣ x 0 ) [ p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] ≥ E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \begin{aligned} \log p_{\theta }(\boldsymbol{ x}_{0}) & =\log \int p_{\theta}(\boldsymbol{x_{0:T}}) d \boldsymbol{x_{1:T}} \\ & =\log \int \ frac{p_{\theta}(\boldsymbol{x_{0:T}}) q_{\phi}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}{q_ {\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})} d \boldsymbol{\boldsymbol{x_{1:T}}} \\ & =\log \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{x_{1:T}} \mid \boldsymbol{x_0})}\left[\frac{p_{\theta } (\boldsymbol{\boldsymbol{x_{0:T}}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \\ & \geq \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T} }} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol {\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \end{整列}ログ_p私( ×0)=ログ_∫p私( ×0 : T) dx _1 : T=ログ_∫qϕ( ×1 : T∣バツ0)p私( ×0 : T) qϕ( ×1 : T∣バツ0)dx _1 : T=ログ_Eqϕ( ×1 : T∣ x0)[qϕ( ×1 : T∣バツ0)p私( ×0 : T)】≥Eqϕ( ×1 : T∣ x0)[ログ_qϕ( ×1 : T∣バツ0)p私( ×0 : T)]
最後のステップは、ピアノの音の不等式(ジェンセンの I 不等式) (ジェンセンの不等式)によって決まります。(ジェンセン_ _ _ _は不平等です)。_________
これはあまり直感的ではないようですが、もっと簡単な別の導出方法があります。
log p ( x ) = log p ( x ) ∫ q ϕ ( z ∣ x ) dz = ∫ q ϕ ( z ∣ x ) ( log p ( x ) ) dz = E q ϕ ( z ∣ x ) [ log p ( x ) ] = E q ϕ ( z ∣ x ) [ log p ( x , z ) p ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ log p ( x , z ) q ϕ ( z ∣ x ) p ( z ∣ x ) q ϕ ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ log p ( x , z ) q ϕ ( z ∣ x ) ] + E q ϕ ( z ∣ x ) [ log q ϕ ( z ∣ x ) p ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ log p ( x , z ) q ϕ ( z ∣ x ) ] + DKL ( q ϕ ( z ∣ x ) ∥ p ( z ∣ x ) ) ≥ E q ϕ ( z ∣ x ) [ log p ( x , z ) q ϕ ( z ∣ x ) ] = E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 :T ∣ x 0 ) ] \begin{aligned} \log p(\boldsymbol{x}) & =\log p(\boldsymbol{x}) \int q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x}) dz \\ & =\int q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})(\log p(\boldsymbol{x})) dz \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}[\log p(\boldsymbol{x})] \\ & =\mathbb {E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{p (\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\ left[\log \frac{p(\boldsymbol{x},\boldsymbol{z}) q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}{p(\boldsymbol{z} \mid \boldsymbol{x}) q_{\phi}(\boldsymbol {z} \mid \boldsymbol{x})}\right] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\ log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\right]+\mathbb {E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x })}{p(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & =\mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{ x})}\left[\log \frac{p(\boldsymbol{x},\boldsymbol{z})}{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\right]+D_{\mathrm{KL}}\left(q_{\boldsymbol{\phi }}(\boldsymbol{z} \mid \boldsymbol{x}) \| p(\boldsymbol{z} \mid \boldsymbol{x})\right) \\ & \geq \mathbb{E}_{q_{ \phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p(\boldsymbol{x}, \boldsymbol{z})}{q_{\boldsymbol{\phi} }(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & = \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1: T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}( \boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \end{整列}T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}( \boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \end{整列}T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}}( \boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] \end{整列}ログ_p ( x )=ログ_p ( x )∫qϕ( z∣x )dz=∫qϕ( z∣x )(ログ_p ( x )) d z=Eqϕ( z ∣ x )[ログ_p ( x )]=Eqϕ( z ∣ x )[ログ_p ( z∣×)p ( x ,z )】=Eqϕ( z ∣ x )[ログ_p ( z∣x ) qϕ( z∣×)p ( x ,z ) qϕ( z∣×)】=Eqϕ( z ∣ x )[ログ_qϕ( z∣×)p ( x ,z )】+Eqϕ( z ∣ x )[ログ_p ( z∣×)qϕ( z∣×)】=Eqϕ( z ∣ x )[ログ_qϕ( z∣×)p ( x ,z )】+Dクアラルンプール( qϕ( z∣x )∥p( z∣×))≥Eqϕ( z ∣ x )[ログ_qϕ( z∣×)p ( x ,z )】=Eqϕ( ×1 : T∣ x0)[ログ_qϕ( ×1 : T∣バツ0)p私( ×0 : T)]
ここにありますz はx 1を表します: T x_{1:T}バツ1 : T、××x はx 0 x_0を表しますバツ0。途中でベイズ公式が使われます。
ここで 2 つの方法によって得られる結論は次のとおりです。
log p θ ( x 0 ) ≥ E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \log p_{\theta }(\boldsymbol{x}_{0}) \geq \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1 :T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_{\boldsymbol{\phi}} (\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right]ログ_p私( ×0)≥Eqϕ( ×1 : T∣ x0)[ログ_qϕ( ×1 : T∣バツ0)p私( ×0 : T)】
その中E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \mathbb{E}_{q_{\boldsymbol{\phi }}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}} )}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right]Eqϕ( ×1 : T∣ x0)[ログ_qϕ( ×1 : T∣ x0)p私( ×0 : T)]就是ELBO(下限証拠) (下限証拠)( Evid e nce Lower Bound ) 、つまり変分下限。_ _ _ _ _ _
損失関数を作成したいのですが、 log p θ ( x 0 ) -\log p_{\theta}\left(x_{0}\right)−ログ_p私( ×0)最小,就是另ELBOE q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] \mathbb{E}_{q_{ \boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{ 0:T}})}{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right]Eqϕ( ×1 : T∣ x0)[ログ_qϕ( ×1 : T∣ x0)p私( ×0 : T)]マックス。
我们令LVLB = − E q ϕ ( x 1 : T ∣ x 0 ) [ log p θ ( x 0 : T ) q ϕ ( x 1 : T ∣ x 0 ) ] = E q ϕ ( x 1 : T ∣ x 0 ) [ log q ϕ ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] L_{VLB} = -\mathbb{E}_{q_{\boldsymbol{\phi}}(\ボールドシンボル{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{p_{\theta }(\boldsymbol{x_{0:T}})}{q_ {\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\right] = \mathbb{E}_{q_{\boldsymbol{\phi }}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0})}\left[\log \frac{q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{ x_{1:T}}} \mid \boldsymbol{x_0})}{p_{\theta }(\boldsymbol{x_{0:T}})}\right]LVLB _ _=−E _qϕ( ×1 : T∣ x0)[ログ_qϕ( ×1 : T∣ x0)p私( ×0 : T)】=Eqϕ( ×1 : T∣ x0)[ログ_p私( ×0 : T)qϕ( ×1 : T∣ x0)] を計算し、損失関数を解くように変換します。さらに細分化します。
\mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1}\mid\mathbf{x}_{t}\right)\right)}_{L_{t-1}}-\underbrace{ \log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}_{L_{0}}] \end{aligned} \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1}\mid\mathbf{x}_{t}\right)\right)}_{L_{t-1}}-\underbrace{ \log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}_{L_{0}}] \end{aligned}LVLB=Eq ( x0 : T)[ログ_p私( ×0 : T)q( ×1 : T∣バツ0)】=Eq[ログ_p私( ×T)∏t = 1Tp私( ×t − 1∣バツた)∏t = 1Tq( ×た∣バツt − 1)】=Eq[ −ログ_p私( ×T)+t = 1∑Tログ_p私( ×t − 1∣バツた)q( ×た∣バツt − 1)】=Eq[ −ログ_p私( ×T)+t = 2∑Tログ_p私( ×t − 1∣バツた)q( ×た∣バツt − 1)+ログ_p私( ×0∣バツ1)q( ×1∣バツ0)】=Eq[ −ログ_p私( ×T)+t = 2∑Tログ_(p私( ×t − 1∣バツた)q( ×t − 1∣バツた、バツ0)⋅q( ×t − 1∣バツ0)q( ×た∣バツ0))+ログ_p私( ×0∣バツ1)q( ×1∣バツ0)】=Eq[ −ログ_p私( ×T)+t = 2∑Tログ_p私( ×t − 1∣バツた)q( ×t − 1∣バツた、バツ0)+t = 2∑Tログ_q( ×t − 1∣バツ0)q( ×た∣バツ0)+ログ_p私( ×0∣バツ1)q( ×1∣バツ0)】=Eq[ −ログ_p私( ×T)+t = 2∑Tログ_p私( ×t − 1∣バツた)q( ×t − 1∣バツた、バツ0)+ログ_q( ×1∣バツ0)q( ×T∣バツ0)+ログ_p私( ×0∣バツ1)q( ×1∣バツ0)】=Eq[ログ_p私( ×T)q( ×T∣バツ0)+t = 2∑Tログ_p私( ×t − 1∣バツた)q( ×t − 1∣バツた、バツ0)−ログ_p私( ×0∣バツ1) ]=Eq[LT
Dクアラルンプール( q( ×T∣バツ0)∥p _私( ×T) )+t = 2∑TLt − 1
Dクアラルンプール( q( ×t − 1∣バツた、バツ0)∥p _私( ×t − 1∣バツた) )−L0
ログ_p私( ×0∣バツ1)]
ここで、最初の添字はq ϕ ( x 1 : T ∣ x 0 ) q_{\boldsymbol{\phi}}(\boldsymbol{\boldsymbol{x_{1:T}}} \mid \boldsymbol{x_0}) から始まります。qϕ( ×1 : T∣バツ0)は q ( x 0 : T ) q\left(\mathbf{x}_{0: T}\right)に変更されます。q( ×0 : T)、私はx 0 x_0 だバツ0は既知であるため、2 つの式は同等です。
真ん中にもう 1 つわかりにくい点があります。q ( x 1 : T ∣ x 0 ) q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)q( ×1 : T∣バツ0)は順伝播の分布q ( xt − 1 ∣ xt ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right) を表します。q( ×t − 1∣バツた)逆拡散過程の真の分布を表します。ここで、qqqには順方向または逆方向の意味はなく、確率と分布を表すだけであり、確率論ではPPP。p θ ( xt − 1 ∣ xt ) p_{\theta}\left(\mathbf{x}_{t-1}\mid \mathbf{x}_{t}\right)p私( ×t − 1∣バツた)は、解決したい逆拡散分布を表します。
次に、 LT 、 L t − 1 、 L 0 L_{T}、 L_{t-1}、 L_{0} があります。LT、Lt − 1、L0これら 3 つの状況を分類して説明します。
2.3.2 LT L_{T}LT
q ( x 1 : T ∣ x 0 ) q\left(\mathbf{x}_{1:T} \mid \mathbf{x}_{0}\right)q( ×1 : T∣バツ0)は順拡散プロセスを表し、学習可能なパラメータはありません;p θ ( x T ) p_{\theta}\left(\mathbf{x}_{T}\right)p私( ×Tx T x_Tで)バツTは、標準ガウス分布p θ p_{\theta}に従うノイズです。p私は逆拡散プロセスです。逆拡散プロセスの場合、x T x_TバツTは既知であるため、この用語LT L_{T}LT定数として使用できます。
2.3.3 L t − 1 L_{t-1}Lt − 1
そしてL t − 1 L_{t-1}Lt − 1実数逆拡散分布 q ( xt − 1 ∣ xt , x 0 ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf であることがわかります。{ x}_{0}\右)q( ×t − 1∣バツた、バツ0)と、必要な逆拡散分布p θ ( xt − 1 ∣ xt ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)p私( ×t − 1∣バツた) KL ダイバージェンス。
- 関数q ( xt − 1 ∣ xt , x 0 ) q\left(\mathbf{x}_{t-1}\mid \mathbf{x}_{t}, \mathbf{x}_{0}\右)q( ×t − 1∣バツた、バツ0)得られた平均と分散:
μ ~ t = 1 α t ( xt − 1 − α t 1 − α ˉ t ε t ) 、 β ~ t = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t \tilde{\mu}_{t}=\frac{1}{\sqrt{\alpha_{t}}}\left(x_{t}-\frac{1-\alpha_{t}}{\ sqrt{ 1-\bar{\alpha}_{t}}} \varepsilon_{t}\right)、\tilde{\beta}_{t}=\frac{1-\bar{\alpha}_{t -1 }}{1-\bar{\alpha}_{t}} \cdot \beta_{t}メートル~た=あるた1( ×た−1−あるˉた1−あるたeた)、b~た=1−あるˉた1−あるˉt − 1⋅bた - 2 番目の分布p θ ( xt − 1 ∣ xt ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)p私( ×t − 1∣バツた)は当てはめたいターゲット分布であり、ガウス分布でもあり、平均はネットワークによって推定され、分散はβ t \beta_tbた形式:
p θ ( xt − 1 ∣ xt ) = N ( xt − 1 ; μ θ ( xt , t ) , Σ θ ( xt , t ) ) p_{\theta}\left(\mathbf{x}_{t -1} \mid \mathbf{x}_{t}\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \ボール記号{\mu}_{\theta} \ left(\mathbf{x}_{t}, t\right), \ballsymbol{\sigma}_{\theta}\left(\mathbf{x}_{t}, t\right)\right);p私( ×t − 1∣バツた)=N( ×t − 1;メートル私( ×た、t )、S私( ×た、た))
したがって、これら 2 つの分布を近づけるには、分散を無視することができ、2 つの分布の平均間の距離を最小化するだけで済みます。第 2 ノルムを使用して次のように表します: L t = E q [ ∥ μ ~
t ( xt , x 0 ) − μ θ ( xt , t ) ∥ 2 ] = E x 0 , ϵ [ ∥ 1 α t ( xt ( x 0 , ϵ ) − β t 1 − α ˉ t ϵ ) − μ θ ( xt ( x 0 , ϵ ) , t ) ∥ 2 ] ϵ ∼ N ( 0 , 1 ) \begin{aligned} L_{t} & =\mathbb{E}_{q}\left[\left\| \チルダ{\boldsymbol{\ mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)-\boldsymbol{\mu}_{\theta }\left(\mathbf{ x}_{t}, t\right)\right\|^{2}\right] \\ & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\ left\|\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0} 、\epsilon\right)- \frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon\right)-\boldsymbol{\mu}_{\theta }\left(\mathbf{ x}_{t}\left(\mathbf{x}_{0}, \epsilon\right), t\right)\right\|^{2}\right] \quad \ epsilon \sim \mathcal{N }(0,1) \end{aligned}Lた=Eq[ ∥メートル~た( ×た、バツ0)−メートル私( ×た、t ) ∥2 ]=Eバツ0、 ϵ[
あるた1( ×た( ×0、) _−1−あるˉたbた) _−メートル私( ×た( ×0、) _、t )
2 ]ϵ〜N ( 0 ,1 )
この式では、 μ θ ( xt ( x 0 , ϵ ) , t ) \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}\left を使用する必要があることがわかります。( \mathbf{x}_{0}, \epsilon\right), t\right)メートル私( ×た( ×0、) _、t )内 1α t ( xt ( x 0 , ϵ ) − β t 1 − α ˉ t ϵ ) \frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x} _{t }\left(\mathbf{x}_{0}, \epsilon\right)-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon\右)あるた1( ×た( ×0、) _−1 −あるˉたbたϵ )、定義:
μ θ ( xt , t ) = 1 α t ( xt − β t 1 − α ˉ t ϵ θ ( xt , t ) ) \太字記号{\mu}_{\theta}\left( \ mathbf{x}_{t}, t\right)=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{ t }}{\sqrt{1-\bar{\alpha}_{t}}}\epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)\right)メートル私( ×た、t )=あるた1( ×た−1−あるˉたbたϵ私( ×た、t ) )
は、ニューラル ネットワークϵ θ ( xt , t ) \epsilon_{\theta}\left(\mathbf{x}_{t}, t\right) を直接ϵ私( ×た、t )を使用してノイズϵ \epsilonϵ。次に、予測されたノイズを定義された式に取り込み、予測平均を計算します。
したがって、損失関数は次のようになります。
L t = E x 0 , ϵ [ ∥ 1 α t ( xt − β t 1 − α ˉ t ϵ ) − 1 α t ( xt − β t 1 − α ˉ t ϵ θ ( xt , t ) ) ∥ ϵ 〜 N ( 0 , 1 ) = E x 0 , ϵ [ ∥ ϵ − ϵ θ ( xt , t ) ∥ 2 ] ϵ 〜 N ( 0 , 1 ) ϵ 〜 N ( 0 , 1 ) を無限小の連続ソルバーとする训练= E x 0 , ϵ [ ∥ ϵ − ϵ θ ( α ˉ tx 0 + 1 − α ˉ t ϵ , t ) ∥ 2 ] , ϵ 〜 N ( 0 , 1 ) \begin{aligned} L_{t} & = \mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\left\|\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf {x }_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon\right)-\frac{1}{\sqrt{ \alpha_ {t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}}\epsilon_{ \theta }\left(\mathbf{x}_{t},t\right)\right)\right\|^{2}\right]\quad\epsilon\sim\mathcal{N}(0,1) \\ & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon}\left[\left\|\epsilon-\epsilon_{\theta}\left(\mathbf{x}_{ t} , t\right)\right\|^{2}\right] \quad \epsilon\sim\mathcal{N}(0,1) \quad \text {定数項の係数をすべて捨ててください。著者はこれがトレーニングに適していると言っています} \\ & =\mathbb{E}_{\mathbf{x}_{0}, \epsilon} \left[\ left\|\epsilon-\epsilon_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar {\alpha} _{t}} \epsilon, t\right)\right\|^{2}\right], \quad \epsilon \sim \mathcal{N}(0,1) \end{aligned}Lた=Eバツ0、 ϵ[
あるた1( ×た−1−あるˉたbた) _−あるた1( ×た−1−あるˉたbたϵ私( ×た、た))
2 ]ϵ〜N ( 0 ,1 )=Eバツ0、 ϵ[ ∥ ϵ−ϵ私( ×た、t ) ∥2 ]ϵ〜N ( 0 ,1 ) 定数項目の係数は捨てて、 訓練したほうが良いと著者は言う =Eバツ0、 ϵ[
ϵ−ϵ私(あるˉたバツ0+1−あるˉたϵ 、t )
2 ]、ϵ〜N ( 0 ,1 )
ネットワークへの入力は、ノイズと線形結合された画像xt x_tです。バツた、それに組み合わされたノイズの真の値はϵ \epsilonです。ϵ、自由関数θ ( α ˉ tx 0 + 1 − α ˉ t ϵ , t ) \epsilon_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x} _ {0}+\sqrt{1-\bar{\alpha}_{t}}\epsilon, t\right)ϵ私(あるˉたバツ0+1−あるˉたϵ 、t )をこのノイズに適合させます。
2.3.4 L 0 L_{0}L0
最終的なL 0 = − log p θ ( x 0 ∣ x 1 ) L_{0}=-\log p_{\theta}\left(x_{0} \mid x_{1}\right)L0=−ログ_p私( ×0∣バツ1)は最後のステップのノイズが含まれた画像x 1 x_1バツ1ノイズ除去画像の生成x 0 x_0バツ0の最尤推定。より良い画像を生成するには、画像上の各ピクセル値が離散対数尤度を満たすように、各ピクセルに対して最尤推定を使用する必要があります。
これを達成するために、逆拡散プロセスの最後の部分がx 1 から x_1 に変更されます。バツ1x 0 x_0へバツ0変換は独立した離散計算に設定されます。つまり、特定のx 1 x_1での最後の変換プロセスでバツ1画像を取得x 0 x_0バツ0ピクセルが互いに独立していると仮定して、対数尤度を満たします。
p θ ( x 0 ∣ x 1 ) = ∏ i = 1 D p θ ( x 0 i ∣ x 1 i ) p_{\theta}\left (x_ {0} \mid x_{1}\right)=\prod_{i=1}^{D} p_{\theta}\left(x_{0}^{i} \mid x_{1}^{ i} \右)p私( ×0∣バツ1)=i = 1∏Dp私( ×0私は∣バツ1私は)
DDDはxxですxの次元、上付き文字iiiは画像内の座標位置を表します。ここでの目標は、特定のピクセルの値がどの程度可能性があるかを判断すること、つまり、対応するタイム ステップt = 1 t=1t=1ノイズの少ない画像xxx内の対応するピクセル値の分布
: N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) \mathcal{N}\left(x ; \mu_{\theta}^{i}\left (x_ {1}, 1\right), \sigma_{1}^{2}\right)N( x ;メートル私私は( ×1、1 )、p12)
ここで、t = 1 t=1t=1のピクセル分布は、多変量ガウス分布から得られます。その対角共分散行列を使用すると、分布を一変量ガウスの積に分割できます: N (
x ; μ θ ( x 1 , 1 ) , σ 1 2 I ) = ∏ i = 1 DN ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) \mathcal{N}\left(x ; \mu_{\theta}\left(x_{1}, 1\right), \sigma_ {1}^{2} \mathbb{I}\right)=\prod_{i=1}^{D} \mathcal{N}\left(x ; \mu_{\theta}^{i}\left( x_{1}, 1\right), \sigma_{1}^{2}\right)N( x ;メートル私( ×1、1 )、p12私)=i = 1∏DN( x ;メートル私私は( ×1、1 )、p12)
ここで、画像が値 0 ~ 255 の範囲 [-1,1] で正規化されていると仮定します。t=0 における各ピクセルのピクセル値を考えると、遷移確率分布p θ ( x 0 ∣ x 1 ) p_{\theta}\left(x_{0} \mid x_{ 1}\right)p私( ×0∣バツ1したがって、 p θ ( x 0 ∣ x 1 ) = ∏ i = 1 D ∫ δ − ( x 0
i ) δ + ( x 0 i ) N ( x ; μ θ i ) ( x 1 , 1 ) , σ 1 2 ) dx δ + ( x ) = { ∞ if x = 1 x + 1 255 if x < 1 δ − ( x ) = { − ∞ if x = − 1 x − 1 255 if x > − 1 \begin{aligned} p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right) & =\prod_{i=1} ^{D} \int_{\delta_{-}\left(x_{0}^{i}\right)}^{\delta_{+}\left(x_{0}^{i}\right)} \ mathcal{N}\left(x ; \mu_{\theta}^{i}\left(\mathbf{x}_{1}, 1\right), \sigma_{1}^{2}\right) dx \\ \delta_{+}(x) & =\left\{\begin{array}{ll} \infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1 \end{array} \quad \delta_{-}(x)=\left\{\begin{array}{ll} -\infty & \text { if } x=-1 \\ x- \frac{1}{255} & \text { if } x>-1 \end{array}\right.\right. \end{整列}p私( ×0∣バツ1)d+( x )=i = 1∏D∫d−( ×0私は)d+( ×0私は)N( x ;メートル私私は( ×1、1 )、p12)dx _={
∞バツ+2551× の場合 =1× の場合 <1d−( × )={
− ∞バツ−2551× の場合 =− 1× の場合 >− 1
この公式は元の論文から来ており、ここではその意味を分析します。つまり、最後のステップx 1 x_1のノイズ画像を追加したいとします。バツ1ノイズ除去された画像をフィットx 0 x_0バツ0、画像の各ピクセルをガウス分布に設定し、合計DDDピクセル。一方x 0 x_0バツ0各ピクセルの元の値の範囲は{ 0 , 1 , … , 255 } \{0,1, \ldots, 255\}です。{
0 ,1 、…、255 } 、正規化後に[ − 1 , 1 ] [-1,1][ − 1 、1 ]の範囲です。
ここで、単一のx 1 x_1を取り出します。バツ1x 1 i x_1^i上のピクセルバツ1私は、関数N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) \mathcal{N}\left(x ; \mu_{\theta}^{i}\left(x_{1}, 1\右)、\sigma_{1}^{2}\右)N( x ;メートル私私は( ×1、1 )、p12)、当てはめられるターゲットはx 0 x_0バツ0対応する位置画素点x 0 i x_0^iバツ0私は ,而 x 0 i x_0^i バツ0私は元の離散空間の値の範囲{ 0 , 1 , … , 255 } \{0,1, \ldots, 255\}{
0 ,1 、…、255 }は連続空間[ − 1 , 1 ] [-1,1][ − 1 、1 ]であるため、元の各離散値は連続空間の間隔に対応し、間隔マッピングの式は次のようになります。
δ + ( x ) = { ∞ if x = 1 x + 1 255 if x < 1 δ − ( x ) = { − ∞ if x = − 1 x − 1 255 if x > − 1 \begin{aligned} \delta_{+}(x) & =\left\{\begin{array}{ll} \infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1 \end{array} \quad \delta_{-}(x)=\left\{\begin{array }{ll} -\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1 \end{array}\right.\right. \end{整列}d+( x )={
∞バツ+2551× の場合 =1× の場合 <1d−( × )={
− ∞バツ−2551× の場合 =− 1× の場合 >− 1
上記は、損失関数構築部分を含む、DDPM の拡散および逆拡散プロセスに含まれるすべての式の解析と導出です。
参考文献とブログ
拡散モデルを理解する: 統一された視点による
ノイズ除去拡散確率モデル
https://yinglinzheng.netlify.app/diffusion-model-tutorial
https://zhuanlan.zhihu.com/p/549623622