Diffusion Model (扩散生成模型)的基本原理详解(二)Score-Based Generative Modeling(SGM)

本篇是《Diffusion Model (扩散生成模型)的基本原理详解(一)Denoising Diffusion Probabilistic Models(DDPM)》的续写,继续介绍有关diffusion的另一个相关模型,同理,参考文献和详细内容与上一篇相同,读者可自行查阅,本篇着重介绍Score-Based Generative Modeling(SGM)的部分,本篇的理论部分参考与上一节相同,当然涉及了一些原文的理论部分,笔者在这里为了更能让各位读懂,略掉了原文的一些理论证明,感兴趣读者可以自行阅读Song Yang et al.SGM原文。笔者只介绍重要思想和重要理论,省略了较多细节篇幅。下一节介绍本基础系列最后一部重点:Stochastic Differential Equation(SDE)。

2、Score-Based Generative Models(SGM)

不同于DDPM,这是一个基于分数的Model,换而言之通过预测评分来获取最终的信息。在阅读本篇之前,这样我们显然会产生两个问题:
一、Network应该以什么为基准的评分?
二、DDPM网络是一个可以直接给出的后验分布(去噪链),那么如果我得到了一个基于评分的网络,如何进行所谓的“采样”,这里没有直接对分布进行所谓的预测。
下面开始进入正题介绍,先从评分函数讲起。

2.1、Score-Function(评分函数)

SGMS的核心思想在于评分函数的定义,让我们来一起看一下它的Score-Function的定义是怎么样定义的,以下是它的定义:

假若有一个概率密度函数 p ( x ) p(x) p(x),定义“Stein-ScoreFunction”如下:
S t e i n − S c o r e F u n c t i o n = ∇ x l o g ( p ( x ) ) Stein-ScoreFunction=\nabla_xlog(p(x)) SteinScoreFunction=xlog(p(x))显然的,该score表示了概率密度函数增长的快慢程度。

2.2、SGM forward Markov Chain(加噪链)—— q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) q(xtxt1)

我们仍旧假设原始数据 x 0 x_0 x0是从某一分布 x 0 ~ q ( x 0 ) x_0~q(x_0) x0q(x0)中采样得到。不同于之前DDPM,SGM使用另外一种更直接化的分布来进行采样,具体操作如下:假设生成噪声步长为 T T T,给予一组逐步弱化的噪声: [ σ 1 , σ 2 , σ 3 ⋅ ⋅ ⋅ σ T ] [\sigma_1,\sigma_2,\sigma_3···\sigma_T] [σ1,σ2,σ3⋅⋅⋅σT],则会有如下结论:
在第 i i i步噪声数据 x i x_i xi满足从如下分布中进行采样:
x i ~ N ( x 0 , σ i 2 I )    ⟺    q ( x i ∣ x 0 ) = N ( x 0 , σ i 2 I )    ⟺    q ( x i ) = ∫ q ( x i ∣ x 0 ) q ( x 0 ) d ( x 0 ) x_i~N(x_0,\sigma_i^2I)\iff q(x_i|x_0)=N(x_0,\sigma_i^2I)\iff q(x_i)=\int q(x_i|x_0)q(x_0)d(x_0) xiN(x0,σi2I)q(xix0)=N(x0,σi2I)q(xi)=q(xix0)q(x0)d(x0)

2.3、Score-Network—— s θ ( x t , t ) s_\theta(x_t,t) sθ(xt,t)

2.3.1、SGM与DDPM的一致性&Loss-Function

不同于DDPM,这里并不是直接训练出一个去噪链直接解决问题并生成数据,我们的目的是想训练出一个可以模拟Score-Function的良好网络,再基于该Score-Function进行反向的采样,即想要设计一个Network : s θ ( x t , t ) s_\theta(x_t,t) sθ(xt,t)用来模拟当前的Score-Function: ∇ x t l o g ( q ( x t ∣ x 0 ) ) \nabla_{x_{t}} log(q(x_t|x_0)) xtlog(q(xtx0))。那么显然地,目标函数变为:
L o s s = ∣ ∣ ∇ x t l o g ( q ( x t ∣ x 0 ) ) − s θ ( x t , t ) ∣ ∣ 2 Loss=||\nabla_{x_{t}} log(q(x_t|x_0))-s_\theta(x_t,t)||^2 Loss=∣∣xtlog(q(xtx0))sθ(xt,t)2
而我们已经知道了 q ( x t ∣ x 0 ) = N ( x 0 , σ t 2 I ) = 1 2 π σ t e − ( x t − x 0 ) 2 2 σ t 2 q(x_t|x_0)=N(x_0,\sigma_t^2I)=\frac{1}{\sqrt{2\pi}\sigma_t}e^{-\frac{(x_t-x_0)^2}{2\sigma_t^{2}}} q(xtx0)=N(x0,σt2I)=2π σt1e2σt2(xtx0)2
那么显然地,我们会有 ∇ x t l o g [ q ( x t ∣ x 0 ) ] = − x t − x 0 σ t 2 \nabla_{x_{t}}log[q(x_t|x_0)]=-\frac{x_t-x_0}{\sigma_t^{2}} xtlog[q(xtx0)]=σt2xtx0
则会有当前优化目标可以视为如下的函数
L o s s = ∣ ∣ − x t − x 0 σ t 2 − s θ ( x t , t ) ∣ ∣ 2 Loss=||-\frac{x_t-x_0}{\sigma_t^{2}}-s_\theta(x_t,t)||^2 Loss=∣∣σt2xtx0sθ(xt,t)2
注意到第一项,这可视为从正态分布进行的采样,差了一个非网络参数 σ t \sigma_t σt即:
L o s s ∗ = σ t 2 L o s s = ∣ ∣ x t − x 0 σ t + σ t s θ ( x t , t ) ∣ ∣ 2 Loss^*=\sigma_t^2Loss=||\frac{x_t-x_0}{\sigma_t}+\sigma_ts_\theta(x_t,t)||^2 Loss=σt2Loss=∣∣σtxtx0+σtsθ(xt,t)2
− σ t s θ ( x t , t ) = z θ ( x t , t ) -\sigma_ts_\theta(x_t,t)=z_\theta(x_t,t) σtsθ(xt,t)=zθ(xt,t)
L o s s ∗ = σ t 2 L o s s = ∣ ∣ z − z θ ∣ ∣ 2 Loss^*=\sigma_t^2Loss=||z-z_\theta||^2 Loss=σt2Loss=∣∣zzθ2
如果读者已经度过了笔者写过的(一)DDPM部分,读者会惊奇的发现,这与DDPM的优化目标是一致的,从原理上,它们的目的是相同的。

2.3.2、SGM-Sampling(Langevin Monte Carlo)

SGM的采样办法有很多种,不同于DDPM的那种“一步一步”反向估计后验估计的办法,这里首先介绍使用Langevin Monte Carlo(基于Langevin Dynamics的一种办法)进行采样,这里介绍算法过程,有关Langevin Monte Carlo理论部分的介绍可见随机过程中的Important-Sampling等会有一些介绍,笔者之后会给予一些简单的补充资料来对Langevin Monte Carlo理论进行说明,这里读者可以认为它的目的是去模拟原始分布 q ( x 0 ) q(x_0) q(x0),然后直接用该分布采样生成数据。
先来介绍Langevin Monte Carlo采样算法的过程:
假设我们已经训练好了一个网络 s θ ( x t , t ) s_\theta(x_t,t) sθ(xt,t),它可以作为 ∇ x t l o g ( q ( x t ∣ x 0 ) ) \nabla_{x_{t}} log(q(x_t|x_0)) xtlog(q(xtx0))的近似,我们下面要利用该网络进行分布预估:给定比步长(固定)为 s ∗ s^* s ( s ∗ (s^* (s足够小 ) ) ),迭代次数为 N N N。下面进行反向生成过程,笔者将其总结为SGM算法:

SGM算法(Langevin Monte Carlo法)
①、随机采样一个样本 x T 0 ~ N ( 0 , 1 ) x_T^{0}~N(0,1) xT0N(0,1) ( T (T (T足够大 ) ) ),记录当前时间 t = T t=T t=T
②、迭代 x t i + 1 = x t i + 1 2 s ∗ s θ ( x t i , t ) + s ∗ z x_t^{i+1}=x_t^{i}+\frac{1}{2}s^*s_\theta(x_t^{i},t)+\sqrt{s^*}z xti+1=xti+21ssθ(xti,t)+s z 直到 i = N i=N i=N
③、 x t − 1 0 = x t T , t = t − 1 x_{t-1}^{0}=x_{t}^T,t=t-1 xt10=xtT,t=t1
④、重复②~③直到 t = 0 t=0 t=0

猜你喜欢

转载自blog.csdn.net/lvoutongyi/article/details/129181819
今日推荐