【计算机视觉】VAE 讲解

VAE

1. 模型概述

变分自编码器(Variational AutoEncoder,VAE)属于生成模型。从概率图的角度看,VAE 是带隐变量的有向概率图模型;从神经网络的角度看,VAE 是以自编码器(AutoEncoder,AE)为框架的生成模型。VAE 通过在由编码器确定的分布中采样的结果作为解码器的输入以对输入图像进行重建,其关键之处在于编码器直接输出的不是潜在空间(编码空间)的具体特征,而是特征分布,这使得 VAE 成为与朴素自编码器不同的生成模型。

具体来说,将模型的解码器结构剥离出来,只有当输入为非常接近模型遇见过的潜在特征时,朴素自编码器模型的解码器才能输出具有合理语义的图像,对于变化比较大的潜在特征,编码器生成的图像不符合人们的预期,甚至无法被人们理解,因此不认为朴素自编码器具备生成能力;而 VAE 通过向潜在特征加入噪音,使得编码器能够对于潜在特征发生变化的输入也具有一定的处理能力,从而保证了其生成图像的合理性。

2. 模型结构

在这里插入图片描述

图 1    VAE 模型

变分自编码器的结构如图 1 1 1 所示。原始图像输入到编码器中得到 o o o α \alpha α,从标准高斯分布中采样确定 e e e,图像的潜在特征为 exp ⁡ ( α ) ⋅ e + o \exp(\alpha)·e+o exp(α)e+o,解码器接收潜在特征进行重建,损失函数由两部分组成:重建损失和正则化项。

2.1. 潜在空间正则化

自编码器的模型结构与 VAE 的结构只有少许的不同。在自编码器中,编码器部分的输出直接作为潜在特征 [ c 1 , c 2 , c 3 ] [c_1, c_2,c_3] [c1,c2,c3] 输入到解码器部分。这种结构设计上的差异导致自编码器在生成数据上无能为力,究其根本是因为自编码器学习到的是不规则的潜在空间。潜在空间的不规则性最直观的体现是,解码器无法将在潜在空间中随机采样的编码解码成合理的新数据,即潜在空间不具有连续性和完整性。连续性是指潜在空间中的两个相近的点在解码后不应给出两个完全不同的内容;完整性是指来自潜在空间的采样点都应该有合理的意义。如图 2 2 2 所示。

在这里插入图片描述

图 2    不规则的潜在空间 与 规则的潜在空间

自编码器的潜在空间出现不规则性的一个直观原因在于,在训练过程中模型忽略潜在空间的组织方式,仅考虑经过编码和解码后的重建图像损失尽可能少。每幅经过编码的图像都对应确定的潜在特征,相当于模型只学习到了潜在空间的零散点的意义,但是模型无法根据已知的点来理解性地解码其它点。

一种比较好的处理方法是,不再将输入图像编码为一个确定点,而是编码为潜在空间中的一个分布,解码器部分以分布中的采样点作为输入进行解码,计算损失。这种方式非常自然地对潜在空间进行了正则化,将编码器部分输出分布尽量限制为标准高斯分布,局部正则化由方差控制,全局正则化由均值控制。可以想象,如果仅要求编码器部分返回一般的高斯分布,也难以保证潜在空间具有连续性和完整性,比如当均值相差非常大、方差非常小,此时的编码器部分与自编码器的编码器部分无异,几乎可以认为输入图像被编码为确定的点,而非分布。为了避免这种情况,引入了对编码器部分输出分布的限制,防止图像在潜在空间中的编码相差甚远,并且鼓励每幅图像对应的分布之间有一定的重合,以保证潜在空间的连续性和完整性。如图 3 3 3 所示。有关正则化作用的一个更形象的例子如图 4 4 4 所示。

在这里插入图片描述

图 3    通过正则化的方式让潜在空间具有良好的性质

在这里插入图片描述

图 4    正则化倾向于在潜在空间中编码的信息上创建“梯度”(理解为渐变即可)

3. 数学细节

3.1. 从 GMM 到 VAE

从概率图角度出发,变分自编码器作为生成模型,最想学习到的是图像的分布 P ( X ) P(X) P(X)。我们最熟悉的概率生成模型是高斯混合模型(GMM),其思想是从有限个类别中按照类别分布情况采样出类别,再从类别对应的高斯分布中采样以生成数据。形式化表示为 P ( X ) = ∑ z P ( z ) P ( X ∣ z ) P(X)=\sum\limits_z P(z)P(X\mid z) P(X)=zP(z)P(Xz),其中 Z Z Z 服从多项分布 P ( Z ) P(Z) P(Z) X ∣ z X\mid z Xz 服从(假设为一元)高斯分布 N ( μ z , σ z ) N(\mu^z, \sigma^z) N(μz,σz),高斯分布中的参数表示不同的 z z z 对应的 μ \mu μ σ \sigma σ

方便起见,用 σ \sigma σ 表示方差,而不是 σ 2 \sigma^2 σ2。大写表示分布的含义,小写表示具体的取值,比如 P ( X ) P(X) P(X) 表示输入图像的分布, P ( x ) P(x) P(x) 表示输入图像 x x x 对应的概率值。

VAE 不过是 GMM 的推广,即隐变量 Z Z Z 不再是有限种选法,而是无限种选法,这样 P ( X ) = ∫ z P ( z ) P ( X ∣ z ) P(X) = \int\limits_zP(z) P(X\mid z) P(X)=zP(z)P(Xz),其中 Z Z Z 服从(假设为一元)高斯分布 N ( 0 , 1 ) N(0, 1) N(0,1) X ∣ z X\mid z Xz 服从(假设为一元)高斯分布 N ( μ ( z ) , σ ( z ) ) N(\mu(z), \sigma(z)) N(μ(z),σ(z)),高斯分布中的参数表示 μ \mu μ σ \sigma σ 是关于 z z z 的函数,不同的 z z z 决定了不同的高斯分布。

可见,GMM 的 P ( X ) P(X) P(X) 是由有限个高斯分布带权相加得到,而 VAE 的 P ( X ) P(X) P(X) 是由无限个高斯分布带权相加得到。二者的对比见图 5 5 5

方便起见,高斯分布都假设为一元高斯分布,但其实很多情况下 VAE 中的 Z Z Z X ∣ z X\mid z Xz 服从高维高斯分布。

在这里插入图片描述

图 5    GMM(左)与VAE(右)

函数 μ ( ⋅ ) \mu(·) μ() σ ( ⋅ ) \sigma(·) σ() 由神经网络确定,该神经网络对应于图 1 1 1 中的解码器部分。本质上,解码器部分的输出应该是,在 Z Z Z 取值确定的前提下,分布 P ( X ∣ z ) P(X\mid z) P(Xz) 对应的具体均值和方差,但是在具体实现时一般认为解码器部分输出的 σ \sigma σ 为人为设定的超参数,解码器仅输出 μ \mu μ,且 μ \mu μ 作为重建图像,就有了图 1 1 1 中的解码器部分。这里直接使用 μ \mu μ 作为重建图像并非毫无意义,编码器部分的输出是一个高斯分布,最后需要将在分布上的采样点作为重建图像,只不过在具体实现中直接使用了采样概率最高的点,即均值作为采样点,这是一种比较合理的方式。

另外,无需担心由于假设分布 P ( Z ) P(Z) P(Z) 为标准高斯分布过于简单导致最终确定的分布 P ( X ~ ) P(\tilde X) P(X~) 质量不佳,这是因为解码器部分的神经网络完全可以拟合任何复杂的函数,进而保证分布足够复杂。

3.2. 损失函数

通过积分的方式计算 P ( X ~ ) = ∫ z P ( z ) P ( X ~ ∣ z ) d z P(\tilde X)=\int\limits_z P(z)P(\tilde X\mid z){\rm d}z P(X~)=zP(z)P(X~z)dz 是不现实的,实际操作中需要通过采样求和的方式近似,即 P ( X ~ ) = ∑ z P ( z ) P ( X ~ ∣ z ) P(\tilde X)=\sum\limits_{z}P(z)P(\tilde X\mid z) P(X~)=zP(z)P(X~z)。在潜在空间中随机采样得到 z z z 是不合理的,由于潜在空间是高维的,有明显的维数灾难问题,即随着空间维数的增加,要想比较合理地描述空间的分布,采样数需要爆炸增长。因此,在潜在空间中随机采样很可能与训练集中图像在潜在空间中的编码相差很远,甚至无关,那么由采样点解码出合理的图像的效果也就越差,这很显然不利于重建。

可以想象,对于某个采样 z z z 和某个图像 x x x P ( x ∣ z ) P(x\mid z) P(xz) 表示潜在编码为 z z z 时生成图像 x x x 的概率,如上面所说,这个概率接近于 0 0 0。进一步,某个重建出的图像 x ~ \tilde x x~ 对应于潜在空间中的采样是有限的,故有 P ( x ~ ) = ∑ z P ( z ) P ( x ~ ∣ z ) P(\tilde x) = \sum\limits_z P(z)P(\tilde x\mid z) P(x~)=zP(z)P(x~z) P ( x ~ ) P(\tilde x) P(x~) 为在 x ~ \tilde x x~ 处对分布 P ( X ~ ) P(\tilde X) P(X~) 的贡献。当所有 x ~ \tilde x x~ 对应的概率值 P ( x ~ ) P(\tilde x) P(x~) 都非常小时,最终确定的分布 P ( X ~ ) P(\tilde X) P(X~) 更像是均匀分布,我们知道均匀分布是不包含任何信息的,因此重建出的这个分布是无意义的、无价值的。

可见,合理地采样是非常有必要的,我们希望对于某个图像 x x x 在潜在空间中的采样尽可能与之相关。依据 Z Z Z 的后验分布 P ( Z ∣ x ) P(Z\mid x) P(Zx),可以通过 x x x 采样得到 z z z,这样的 z z z 包含了与 x x x 相关的丰富信息,大概率能够由此生成优质的 x ~ \tilde x x~。后验概率分布是未知的,根据变分思想,引入高斯分布 q ( Z ∣ x ) = N ( μ ′ ( x ) , σ ′ ( x ) ) q(Z\mid x)=N(\mu'(x), \sigma'(x)) q(Zx)=N(μ(x),σ(x)) 来近似求解。

其中,函数 μ ′ ( ⋅ ) \mu'(·) μ() σ ′ ( ⋅ ) \sigma'(·) σ() 与函数 μ ( ⋅ ) \mu(·) μ() σ ( ⋅ ) \sigma(·) σ() 类似,前者是通过编码器部分神经网络确定,后者是通过解码器部分神经网络确定。

以最大化对数似然为优化目标,即最大化 ∑ x log ⁡ P ( x ) \sum\limits_{x}\log P(x) xlogP(x),其中 x x x 为观测图像, P ( x ) P(x) P(x) 在重建分布中 x x x 对应的概率值。以训练集仅包含一个图像样本为例进行推导:
log ⁡ P ( x ) = ∫ z q ( z ∣ x ) log ⁡ P ( x ) d z = ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) P ( z ∣ x ) d z = ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) q ( z ∣ x ) q ( z ∣ x ) P ( z ∣ x ) d z = ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) q ( z ∣ x ) + q ( z ∣ x ) log ⁡ q ( z ∣ x ) P ( z ∣ x ) d z = ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) q ( z ∣ x ) d z + ∫ z q ( z ∣ x ) log ⁡ q ( z ∣ x ) P ( z ∣ x ) d z = E L B o + D K L [ q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ] \begin{align} \log P(x) &= \int_z q(z\mid x)\log P(x) {\rm d}z \notag\\ &= \int_z q(z\mid x)\log \frac{P(z, x)}{P(z\mid x)} {\rm d}z\notag \\ &= \int_z q(z\mid x)\log \frac{P(z, x)}{q(z\mid x)}\frac{q(z\mid x)}{P(z\mid x)} {\rm d}z\notag \\ &= \int_z q(z\mid x)\log \frac{P(z, x)}{q(z\mid x)} + q(z\mid x)\log\frac{q(z\mid x)}{P(z\mid x)} {\rm d}z \notag\\ &= \int_z q(z\mid x)\log \frac{P(z, x)}{q(z\mid x)}{\rm d}z + \int_zq(z\mid x)\log\frac{q(z\mid x)}{P(z\mid x)} {\rm d}z \notag\\ &= ELBo + D_{KL}[q(z\mid x) || P(z\mid x)]\notag \end{align} logP(x)=zq(zx)logP(x)dz=zq(zx)logP(zx)P(z,x)dz=zq(zx)logq(zx)P(z,x)P(zx)q(zx)dz=zq(zx)logq(zx)P(z,x)+q(zx)logP(zx)q(zx)dz=zq(zx)logq(zx)P(z,x)dz+zq(zx)logP(zx)q(zx)dz=ELBo+DKL[q(zx)∣∣P(zx)]
其中, E L B o ELBo ELBo 为证据下界(Evidence Lower Bound), D K L D_{KL} DKL 表示 KL 散度。

利用 EM 算法的迭代方式来理解 log ⁡ P ( x ) \log P(x) logP(x) E L B o ELBo ELBo D K L D_{KL} DKL 以及最大化过程是是最直观的,因此先以交替迭代更新的方式来讲解。引入 q q q 的好处在于,因为 log ⁡ P ( x ) = ∫ z q ( z ∣ x ) log ⁡ P ( x ) d z \log P(x) = \int\limits_z q(z\mid x)\log P(x) {\rm d}z logP(x)=zq(zx)logP(x)dz,所以仅调整 q q q 不影响 log ⁡ P ( x ) \log P(x) logP(x)。由于 D K L ≥ 0 D_{KL}\ge 0 DKL0 恒成立,因此存在关系 log ⁡ P ( x ) ≥ E L B o \log P(x) \ge ELBo logP(x)ELBo,即 E L B o ELBo ELBo log ⁡ P ( x ) \log P(x) logP(x) 的下界。可以想象,如果调整 q q q 使得 D K L D_{KL} DKL 尽可能小,甚至为 0 0 0,那么此时再提高函数下界 E L B o ELBo ELBo 很可能就会让 log ⁡ P ( x ) \log P(x) logP(x) 上升以迭代的方式实现最大化 log ⁡ P ( x ) \log P(x) logP(x)。交替迭代过程如图 6 6 6 所示。

在这里插入图片描述

图 6    交替迭代过程

在 VAE 的神经网络中,无需关注具体的交替迭代过程,只需要定义一个合适的损失函数,满足最大化 log ⁡ P ( x ) \log P(x) logP(x),同时最小化 D K L [ q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ] D_{KL}[q(z\mid x) ||P(z\mid x)] DKL[q(zx)∣∣P(zx)]。显然,最佳的损失函数是 E L B o ELBo ELBo,即损失函数为 L = log ⁡ P ( x ) − D K L [ q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ] \mathcal L = \log P(x) - D_{KL}[q(z\mid x) || P(z\mid x)] L=logP(x)DKL[q(zx)∣∣P(zx)]。对应到结构上,编码器部分的作用是找到合适的 q ( z ∣ x ) q(z\mid x) q(zx) D K L D_{KL} DKL 尽可能小;解码器部分的作用是找到合适的 P ( x ∣ z ) P(x\mid z) P(xz) E L B o ELBo ELBo 尽可能提高。

将损失函数 L \mathcal L L 展开:
L = ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) q ( z ∣ x ) d z = ∫ z q ( z ∣ x ) log ⁡ P ( x ∣ z ) P ( z ) q ( z ∣ x ) d z = ∫ z q ( z ∣ x ) log ⁡ P ( z ) q ( z ∣ x ) d z + ∫ z q ( z ∣ x ) log ⁡ P ( x ∣ z ) d z = − D K L [ q ( z ∣ x ) ∣ ∣ P ( z ) ] + E q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] \begin{align} \mathcal L &= \int_z q(z\mid x)\log \frac{P(z, x)}{q(z\mid x)}{\rm d}z \notag \\ &=\int_z q(z\mid x)\log \frac{P(x\mid z)P(z)}{q(z\mid x)}{\rm d}z \notag\\ &=\int_z q(z\mid x)\log \frac{P(z)}{q(z\mid x)}{\rm d}z + \int_z q(z\mid x)\log P(x\mid z){\rm d}z\notag \\ &= -D_{KL}[q(z\mid x) || P(z)] + {\mathbb E}_{q(z\mid x)}[\log P(x\mid z)]\notag \end{align} L=zq(zx)logq(zx)P(z,x)dz=zq(zx)logq(zx)P(xz)P(z)dz=zq(zx)logq(zx)P(z)dz+zq(zx)logP(xz)dz=DKL[q(zx)∣∣P(z)]+Eq(zx)[logP(xz)]
已经假设 P ( z ∣ x ) P(z\mid x) P(zx) 服从一元高斯分布, P ( z ) P(z) P(z) 服从一元标准高斯分布,故有
D K L [ q ( z ∣ x ) ∣ ∣ P ( z ) ] = D K L [ N ( μ ′ , σ ′ 2 ) ∣ ∣ N ( 0 , 1 ) ] = ∫ z 1 2 π σ ′ 2 exp ⁡ ( − ( z − μ ′ ) 2 2 σ ′ 2 ) log ⁡ 1 2 π σ ′ 2 exp ⁡ ( − ( z − μ ′ ) 2 2 σ ′ 2 ) 1 2 π exp ⁡ ( − z 2 2 ) d z = ∫ z ( − ( z − μ ′ ) 2 2 σ ′ 2 + z 2 2 − log ⁡ σ ′ ) N ( μ ′ , σ ′ 2 ) d z = − ∫ z ( z − μ ′ ) 2 2 σ ′ 2 N ( μ ′ , σ ′ 2 ) d z + ∫ z z 2 2 N ( μ ′ , σ ′ 2 ) d z − ∫ z log ⁡ σ ′ N ( μ ′ , σ ′ 2 ) d z = − E [ ( z − μ ′ ) 2 ] 2 σ ′ 2 + E [ z 2 ] 2 − log ⁡ σ ′ = 1 2 ( − 1 + σ ′ 2 + μ ′ 2 − log ⁡ σ ′ 2 ) \begin{align} D_{KL}[q(z\mid x) || P(z)] &= D_{KL}[N(\mu', \sigma'^2) || N(0, 1)]\notag \\ &= \int_z \frac{1}{\sqrt{2\pi\sigma'^2}}\exp\left( -\frac{(z-\mu')^2}{2\sigma'^2} \right) \log\frac{\frac{1}{\sqrt{2\pi\sigma'^2}}\exp\left( -\frac{(z-\mu')^2}{2\sigma'^2} \right)}{\frac{1}{\sqrt{2\pi}}\exp(-\frac{z^2}{2})}{\rm d}z\notag \\ &= \int_z\left( \frac{-(z-\mu')^2}{2\sigma'^2} + \frac{z^2}{2}-\log\sigma' \right)N(\mu', \sigma'^2){\rm d}z \notag\\ &= -\int_z\frac{(z-\mu')^2}{2\sigma'^2} N(\mu', \sigma'^2){\rm d}z+\int_z\frac{z^2}{2} N(\mu', \sigma'^2){\rm d}z - \int_z \log \sigma' N(\mu', \sigma'^2){\rm d}z\notag \\ &= -\frac{\mathbb E\left[ (z-\mu')^2\right]}{2\sigma'^2}+\frac{\mathbb E\left[z^2\right]}{2} - \log \sigma' \notag\\ &=\frac{1}{2} (-1+\sigma'^2+\mu'^2-\log\sigma'^2)\notag \end{align} DKL[q(zx)∣∣P(z)]=DKL[N(μ,σ′2)∣∣N(0,1)]=z2πσ′2 1exp(2σ′2(zμ)2)log2π 1exp(2z2)2πσ′2 1exp(2σ′2(zμ)2)dz=z(2σ′2(zμ)2+2z2logσ)N(μ,σ′2)dz=z2σ′2(zμ)2N(μ,σ′2)dz+z2z2N(μ,σ′2)dzzlogσN(μ,σ′2)dz=2σ′2E[(zμ)2]+2E[z2]logσ=21(1+σ′2+μ′2logσ′2)
将一元高斯分布推广到 d d d 元独立高斯分布,得:
D K L [ q ( z ∣ x ) ∣ ∣ P ( z ) ] = ∑ j = 1 d 1 2 ( − 1 + σ ′ ( j ) 2 + μ ′ ( j ) 2 − log ⁡ σ ′ ( j ) 2 ) D_{KL}[q(z\mid x)|| P(z)] = \sum_{j=1}^d \frac{1}{2} (-1+{\sigma'^{(j)}}^2+{\mu'^{(j)}}^2-\log{\sigma'^{(j)}}^2) DKL[q(zx)∣∣P(z)]=j=1d21(1+σ(j)2+μ(j)2logσ(j)2)
其中 a ( j ) {a^{(j)}} a(j) 表示向量 a a a 的第 j j j 个元素。

通过采样的方式来近似求解期望部分,即:
E q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] ≈ 1 m ∑ i = 1 m log ⁡ P ( x ∣ z i ) \mathbb E_{q(z\mid x)}[\log P(x\mid z)]≈\frac{1}{m} \sum_{i=1}^m \log P(x\mid z_i) Eq(zx)[logP(xz)]m1i=1mlogP(xzi)
其中, z i ∼ q ( z ∣ x i ) = N ( μ ′ ( x i ) , σ ′ ( x i ) ) z_i\sim q(z\mid x_i) = N(\mu'(x_i), \sigma'(x_i)) ziq(zxi)=N(μ(xi),σ(xi))。假设图像对应向量维度为 K K K,即 X ∣ z i X\mid z_i Xzi 服从 K K K 维高斯分布。根据 X ∣ z i ∼ P ( X ∣ z i ) = N ( μ ( z i ) , σ ( z i ) ) X\mid z_i\sim P(X\mid z_i) = N(\mu(z_i), \sigma(z_i)) XziP(Xzi)=N(μ(zi),σ(zi)) log ⁡ P ( x ∣ z i ) \log P(x\mid z_i) logP(xzi) 展开,有:
log ⁡ p θ ( x ∣ z i ) = log ⁡ exp ⁡ ( − 1 2 ( x − μ ) T Σ − 1 ( X − μ ′ ) ) ( 2 π ) k ∣ Σ ∣ = − 1 2 ( x − μ ) T Σ − 1 ( x − μ ) − log ⁡ ( 2 π ) k ∣ Σ ∣ = − 1 2 ∑ k = 1 K ( x ( k ) − μ ( k ) ) 2 σ ( k ) − log ⁡ ( 2 π ) K ∏ k = 1 K σ ( k ) \begin{align} \log p_{\theta}\left(x \mid z_{i}\right) &= \log \frac{\exp \left(-\frac{1}{2}(x-\mu^{})^{\mathrm{T}} {\Sigma}^{-1}({X}-{\mu^{\prime}})\right)}{\sqrt{(2 \pi)^{k}|{\Sigma^{}}|}}\notag \\ &= -\frac{1}{2}(x-\mu^{})^{\mathrm{T}} {\Sigma}^{-1}({x}-{\mu^{}}) - \log \sqrt{(2 \pi)^{k}|\Sigma^{}|}\notag \\ &= -\frac{1}{2} \sum_{k=1}^K \frac{(x^{(k)}-\mu^{(k)})^2}{\sigma^{(k)}} - \log \sqrt{(2 \pi)^{K}\prod_{k=1}^{K} \sigma^{(k)}}\notag \end{align} logpθ(xzi)=log(2π)kΣ exp(21(xμ)TΣ1(Xμ))=21(xμ)TΣ1(xμ)log(2π)kΣ =21k=1Kσ(k)(x(k)μ(k))2log(2π)Kk=1Kσ(k)
当训练集只包含一张图像时,损失函数可以写为:
L = − E L B o = D K L [ q ( z ∣ x ) ∣ ∣ P ( z ) ] − E q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] = D K L [ q ( z ∣ x ) ∣ ∣ P ( z ) ] − 1 m ∑ i = 1 m log ⁡ P ( x ∣ z i ) \begin{align} \mathcal L&= - ELBo\notag \\ &= D_{KL} [q(z\mid x) || P(z)] - \mathbb E_{q(z\mid x)}[\log P(x\mid z)] \notag\\ &= D_{KL} [q(z\mid x) || P(z)] - \frac{1}{m} \sum_{i=1}^m \log P(x\mid z_i)\notag \\ \end{align} L=ELBo=DKL[q(zx)∣∣P(z)]Eq(zx)[logP(xz)]=DKL[q(zx)∣∣P(z)]m1i=1mlogP(xzi)
一般地, m = 1 m=1 m=1,损失函数可以进一步展开:
L = D K L [ q ( z ∣ x ) ∣ ∣ P ( z ) ] − log ⁡ P ( x ∣ z i ) = ∑ j = 1 d 1 2 ( − 1 + σ ′ ( j ) 2 + μ ′ ( j ) 2 − log ⁡ σ ′ ( j ) 2 ) − ( − 1 2 ∑ k = 1 K ( x ( k ) − μ ( k ) ) 2 σ ( k ) − log ⁡ ( 2 π ) K ∏ k = 1 K σ ( k ) ) \begin{align} \mathcal L &= D_{KL} [q(z\mid x) || P(z)] - \log P(x\mid z_i) \notag\\ &= \sum_{j=1}^d \frac{1}{2}(-1 + {\sigma'^{(j)}}^{2} + {\mu'^{(j)}}^{2} - \log {\sigma'^{(j)}}^{2})\notag \\ &\quad -\left( -\frac{1}{2} \sum_{k=1}^K \frac{(x^{(k)}-\mu^{(k)})^2}{\sigma^{(k)}} - \log \sqrt{(2 \pi)^{K}\prod_{k=1}^{K} \sigma^{(k)}} \right)\notag \end{align} L=DKL[q(zx)∣∣P(z)]logP(xzi)=j=1d21(1+σ(j)2+μ(j)2logσ(j)2) 21k=1Kσ(k)(x(k)μ(k))2log(2π)Kk=1Kσ(k)
上面提到过,在具体实现时来自解码器部分的 σ \sigma σ 会被认为是超参数,不妨令超参数 σ \sigma σ 为元素值全为 1 2 \frac{1}{2} 21 K K K 维向量。损失函数改写为:
L = 1 2 ∑ j = 1 d ( − 1 + σ ′ ( j ) 2 + μ ′ ( j ) 2 − log ⁡ σ ′ ( j ) 2 ) + ∥ x − μ ∥ 2 \mathcal{L} = \frac{1}{2}\sum_{j=1}^d (-1 + {\sigma'^{(j)}}^{2} + {\mu'^{(j)}}^{2} - \log {\sigma'^{(j)}}^{2}) + \|x - \mu^{}\|^2 L=21j=1d(1+σ(j)2+μ(j)2logσ(j)2)+xμ2
损失函数 L \mathcal L L 的第一项被认为是约束潜在空间带来的损失,第二项被认为是重建图像带来的损失。对第二项的理解是直观的,上面介绍了直接使用 μ \mu μ 作为重建图像的合理性,第二项是通过对应像素插值的平方和来评估重建带来的损失,这种做法非常常见。相比较而言,第二项就不那么直观了,接下来讨论第一项。图 1 1 1 中的符号与上式中的符号满足关系: exp ⁡ ( α ) = σ ′ 2 \exp(\alpha) = \sigma'^2 exp(α)=σ′2 o = μ ′ o = \mu' o=μ。在图 1 1 1 中之所以对 α \alpha α 取指数,是因为方差恒正,取指数操作可以避免通过添加激活函数保证编码器输出的正负性。以图 1 1 1 中的符号表示损失函数的第一项(的关键部分)为:
∑ j = 1 d ( − 1 + exp ⁡ ( α ( j ) ) + o ( j ) 2 − α ( j ) ) = ∑ j = 1 d ( exp ⁡ ( α ( j ) ) − ( 1 + α ( j ) ) + o ( j ) 2 ) \sum\limits_{j=1}^d (-1+\exp(\alpha^{(j)}) + {o^{(j)}}^2 - \alpha^{(j)}) \\ =\sum\limits_{j=1}^d (\exp(\alpha^{(j)}) -(1+ \alpha^{(j)}) + {o^{(j)}}^2) j=1d(1+exp(α(j))+o(j)2α(j))=j=1d(exp(α(j))(1+α(j))+o(j)2)
本质上,这才是 VAE 模型的损失函数。其中, o 2 o^2 o2 可以被认为是正则化项。 exp ⁡ ( α ) \exp(\alpha) exp(α) 1 + α 1+\alpha 1+α exp ⁡ ( α ) − ( 1 + α ) \exp(\alpha)-(1+\alpha) exp(α)(1+α) 分别在图 7 7 7 中由蓝线、绿线和红线表示。当 α = 0 \alpha=0 α=0 时,对应的方差为 exp ⁡ ( α ) = 1 \exp(\alpha)=1 exp(α)=1,此时的损失值为 exp ⁡ ( α ) − ( 1 + α ) = 0 \exp(\alpha)-(1+\alpha)=0 exp(α)(1+α)=0,即最低。可见,损失函数第一项中的 exp ⁡ ( α ) − ( 1 + α ) \exp(\alpha)-(1+\alpha) exp(α)(1+α) 部分保证在方差为 1 1 1 时损失最小而不是在方差为 0 0 0 时。如果不进行这样的约束,那么模型会倾向于学习到方差为 0 0 0,因为此时来自高斯分布的噪声 e e e 会失效,如此重建效果会更好,但是这也意味着模型几乎退化成朴素自动编码器模型,因此让方差为 1 1 1 时的损失最小是非常有意义的。

在这里插入图片描述

图 7    函数曲线

3.3. 重参数化技巧

在某个分布中采样是不涉及梯度计算的,也就无法实现反向传播,进而阻碍了训练的进行。编码器部分的输出是 μ ′ \mu' μ σ ′ \sigma' σ,解码器部分接收来自 N ( μ ′ , σ ′ ) N(\mu',\sigma') N(μ,σ) 的采样,可以想象虽然前向传播的过程不受影响,但是在反向传播计算梯度时会在此处卡住。重参数化技巧很好地解决了这个问题。重参数化技巧将采样过程与需要反向传播更新参数的计算过程解耦分离,原本是在分布 N ( μ ′ , σ ′ ) N(\mu', \sigma') N(μ,σ) 中直接进行采样,而现在则先从 N ( 0 , 1 ) N(0, 1) N(0,1) 中采样得到一个系数,将其与 σ ′ \sigma' σ 之积、与 μ ′ \mu' μ 之和作为潜在空间编码。很容易证明重参数化前后的采样都是来自相同的高斯分布。如图 8 8 8 所示。

在这里插入图片描述

图 8    重参数化技巧

REF

[1] Lecture 6 Extra 变分自编码器(Variational Auto-Encoder)- bilibili

[2] 17- Unsupervised Learning - Deep Generative Model (Part I)_- bilibili

[3] 一文搞懂变分自编码器(VAE, CVAE) - 简书

[4] 机器学习方法—优雅的模型(一):变分自编码器(VAE) - 知乎

[5] Understanding Variational Autoencoders (VAEs) | by Joseph Rocca | Towards Data Science

[6] VAE-变分自编码–原理一目了然 - 知乎

[7] 理解差分自动编码器 VAE:Variational AutoEncoder_vae 动机 - CSDN

[8] VAE原理详细解释(读书笔记)- CSDN

[9] 再谈变分自编码器(VAE):估计样本概率密度 - CSDN

[10] 变分自编码器VAE:原来是这么一回事 | 附开源代码 - 知乎

[11] 变分推断(Variational Inference)- CSDN

[12] 变分推断(Variational Inference)解析 - CSDN

[13] EM算法与变分推断_-CSDN

[14] 【机器学习】聚类【Ⅲ】高斯混合模型讲解_- CSDN

[15] 【机器学习】EM 算法 - CSDN

猜你喜欢

转载自blog.csdn.net/weixin_46221946/article/details/129845904