Diffusion Model 浅学笔记

Diffusion Model

Created by: 银晗 张
Created time: May 29, 2023 8:12 AM

VAE → GAN →Diffusion

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-E2gVx3hv-1690185475251)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled.png)]

要讲扩散模型,不得不提VAE。VAE和GAN一样,都是从隐变量Z生成目标数据X。

它们假设隐变量服从某种常见的概率分布(比如正态分布),然后希望训练一个模型

X = g ( Z ) X=g(Z) X=g(Z),这个模型将原来的概率分布映射到训练集的概率分布,也就是分布的变换。

  • 注意,VAE和GAN的本质都是概率分布的映射

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PL7nPpNo-1690185475252)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%201.png)]

是不是听上去很work?

但是这种方法本质上是难以work的,因为尽量接近并没有一个确定的关于 X X X X ˉ \bar{X} Xˉ的相似度的评判标准。

换句话说,这种方法的难度就在于,必须去猜测“它们的分布相等吗”这个问题,而缺少真正interpretable的价值判断。

有聪明的同学会问,KL散度不就够了吗?不行,因为KL散度是针对两个已知的概率分布求相似度的,而 X ˉ 和 X \bar{X}和X XˉX概率分布目前都是未知

VAE

它本质上就是在我们常规的自编码器的基础上,对 encoder 的结果(在VAE中对应着计算均值的网络)加上了“高斯噪声”,使得结果 decoder 能够对噪声有鲁棒性;而那个额外的 KL loss(目的是让均值为 0,方差为 1),事实上就是相当于对 encoder 的一个正则项,希望 encoder 出来的东西均有零均值。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Jk3vXJzX-1690185475253)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%202.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xUP9lMag-1690185475253)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%203.png)]

为了使模型具有生成能力,VAE 要求每个 p(Z_X) 都向正态分布看齐。

那怎么让所有的 p(Z|X) 都向 N(0,I) 看齐呢?如果没有外部知识的话,其实最直接的方法应该是在重构误差的基础上中加入额外的 loss:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WrMR0xCl-1690185475254)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%204.png)]

因为它们分别代表了均值 μ k μ_k μk 和方差的对数 l o g σ 2 logσ^2 logσ2,达到 ∗ N ( 0 , I ) ∗ *N(0,I)* N(0,I) 就是希望二者尽量接近于 0 了。不过,这又会面临着这两个损失的比例要怎么选取的问题,选取得不好,生成的图像会比较模糊。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-h9e9n1Pu-1690185475255)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%205.png)]

这里的 d 是隐变量 Z 的维度,而 μ ( i ) μ(i) μ(i) σ ( i ) 2 σ_{(i)}^{2} σ(i)2 分别代表一般正态分布的均值向量和方差向量的第 i 个分量。直接用这个式子做补充 loss,就不用考虑均值损失和方差损失的相对比例问题了。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cVhegdEr-1690185475255)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%206.png)]

重参数技巧

其实很简单,就是我们要从 ∗ p ( Z ∣ X k ) ∗ *p(Z|X_k)* p(ZXk) 中采样一个 Z k Z_k Zk 出来,尽管我们知道了 ∗ p ( Z ∣ X k ) ∗ *p(Z|X_k)* p(ZXk) 是正态分布,但是均值方差都是靠模型算出来的,我们要靠这个过程反过来优化均值方差的模型,但是“采样”这个操作是不可导的,而采样的结果是可导的,于是我们利用了一个事实:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DbK7wkfw-1690185475256)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%207.png)]

所以,我们将从 N ( μ , σ 2 ) N(μ,σ^2) N(μ,σ2) 采样变成了从 ∗ N ( 0 , σ 2 ) ∗ *N(0,σ^2)* N(0,σ2) 中采样,然后通过参数变换得到从 ∗ N ( μ , σ 2 ) ∗ *N(μ,σ^2)* N(μ,σ2) 中采样的结果。

这样一来,“采样”这个操作就不用参与梯度下降了,改为采样的结果参与,使得整个模型可训练了。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-i5ClVYKc-1690185475257)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%208.png)]

Diffusion

  • VAE的生成器,是将标准高斯映射到数据样本(自己定义的)。VAE的后验分布,是将数据样本映射到标准高斯(学出来的)。

那反过来,我想要设计一种方法A,使得A用一种简单的“变分后验”将数据样本映射到标准高斯(自己定义的),并且使得A的生成器,将标准高斯映射到数据样本(学出来的)

  • 注意,因为生成器的搜索空间大于变分后验,VAE的效率远不及A方法:因为A方法是学一个生成器(搜索空间大),所以可以直接模仿这个“变分后验”的每一小步

所以,A学的是样本到标准高斯分布 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2)的的映射关系 f f f,如何学?马尔可夫链


马尔可夫链:

  • 最重要的性质:平稳性

一个概率分布如果随时间变化,那么在马尔可夫链的作用下,它一定会趋于某种平稳分布(例如高斯分布)。只要终止时间足够长,概率分布就会趋近于这个平稳分布。

这个逐渐逼近的过程被作者称为前向过程(forward process)。**注意,这个过程的本质还是加噪声!**试想一下为什么……其实和VAE非常相似,都是在随机采样!马尔可夫链每一步的转移概率,本质上都是在加噪声。这就是扩散模型中“扩散”的由来:噪声在马尔可夫链演化的过程中,逐渐进入diffusion体系。

**物理扩散过程:**随着时间的推移,加入的噪声(加入的溶质)越来越少,而体系中的噪声(这个时刻前的所有溶质)逐渐在diffussion体系中扩散,直至均匀。

  • 扩散模型的本质基于马尔可夫链的前向过程,其每一个epoch的逆过程都可以近似为高斯分布。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IcQHfeCO-1690185475258)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%209.png)]

前向过程通过马尔可夫链的转移概率不断加入噪音,从右边的采样数据到左边的标准高斯;

反向过程通过SDE来“抄袭”对应正向过程的那一个epoch的行为(其实每一步都不过是一个高斯分布),从而逐渐学习到对抗噪声的能力。高斯分布是一种很简单的分布,运算量小,这一点是diffusion快的最重要原因。

公式推导

前向过程:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-J0fF4gqO-1690185475259)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%2010.png)]

任意时刻的 X t X_t Xt 可以由 X 0 X_0 X0 β \beta β 表示

  • 方差系数: 1 − β \sqrt{1-\beta} 1β ; 均值系数: β \beta β

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Pqsk5A2C-1690185475260)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%2011.png)]

DDPM的每一步的推断可以总结为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pBDE7Qtj-1690185475260)(Diffusion%20Model%20f63a7539928247c8aec2be9d29737ab3/Untitled%2012.png)]

猜你喜欢

转载自blog.csdn.net/RandyHan/article/details/131898380
今日推荐