VAE以及tensorflow-2.0实现

Variational autoencoders(VAE)由Kingma et al.和Rezende et al.在2013年提出,它在图像生成、强化学习和自然语言处理等多个领域都有很广泛的应用。


在这里插入图片描述

下面的主要内容翻译自《Tutorial - What is a variational autoencoder?》,并同时加入自己的理解以及其他相关资料的补充。

VAE作为一种新的生成模型,相比于Autoencoder更加复杂一些,但是生成结果的质量自然更高一些。接下来我们从神经网络和概率模型两个角度分别对其进行阐述,并希望构建起连接两种角度的桥梁。

The neural net perspective

从VAE的模型架构来看并不复杂,它主要包含Encoder、Decoder和对应的损失函数。Encoder和Decoder都是用神经网络进行表示,其中Encoder和Decoder分别参数化表示为 q θ q_{\theta} p ϕ p_{\phi}


在这里插入图片描述

假设模型此刻所使用的数据集为MNIST,那么Encoder的输入 x x 就是一个 28 × 28 28 \times 28 的手写数字图像,即一个784维的向量。然后Encoder将 x x 转换到隐式表示空间(latent representation space) z z ,由于表示的稀疏性,我们希望 z z 可以捕获 x x 中重要的信息用于Decoder的重构,因此 z z 的维度要小于784。这里我们将Encoder的输入记为 q θ ( z x ) q_{\theta}(z|x) ,它表示为一个高斯概率密度函数,然后从中进行采样得到 z z

Decoder将 z z 做为此阶段的输入,并输出重构后的结果 p ϕ ( x z ) p_{\phi}(x|z) 。如果这里只使用手写数字图像的二值化形式,每个像素都可以使用一个伯努利分布表示。那么Decoder输出的784维的伯努利参数,每一个参数对应图像中的一个像素。但是由于Decoder接收到的并非原始输入完整的信息,因此重构后的结果会有一定比例的信息丢失。为了评估生成的结果,这里使用重构损失来评估 p ϕ ( x z ) p_{\phi}(x|z) x x

重构损失使用负对数似然,它主要起到一种类似于正则化器的作用。因为并不存在针对于所有数据点(datapoint)的全局信息,因此我们将损失函数解耦为针对单个数据点 l i l_{i} 的形式,最后将所有数据点的损失求和。 l i l_{i} 的损失表示为:
l i ( θ , ϕ ) = E z q θ ( z x i ) [ log p ϕ ( x i z ) ] + K L ( q θ ( z x i ) p ( z ) ) l_{i}(\theta,\phi)=-E_{z \sim q_{\theta}(z|x_{i})}[\log p_{\phi}(x_{i}|z)] + KL(q_{\theta}(z|x_{i}) || p(z))

其中第一项表示重构损失或第 i i 个数据点的负对数似然的期望,它用于迫使Decoder学会重构输入。如果重构结果并不好的话,从统计上来说,即它输出所满足的分布和真实数据所满足的分布相差较大。假设我们现在处理的是黑白图像,如果Decoder在本来为白色点的位置为黑色的点分配了较高的概率,则会导致最后重构的结果质量很低,同时带来的重构损失自然也就很大。

第二项计算 q θ ( z x ) q_{\theta}(z|x) p ( z ) p(z) 之间的KL散度,根据KL散度的原理可知,它表示了两者之间的接近程度,同时表示了在用 q q 表示 p p 的过程中丢失了多少信息。

VAE中 p p 通常满足标准正态分布 p ( z ) = N o r m a l ( 0 , 1 ) p(z) = Normal(0,1) ,如果Encoder的输出 z z p p 相差较远,那么损失函数中的惩罚项自然很大。这样做是为了使每个数字图像得到的 z z 的表示足够多样化,即对于同一个数字但由不同的人所写得到的图像应该是不同的,同时由于它们都是描述同一个数字,因此在 z z 的表示空间中又不会距离的很远。如果我们舍弃掉这个正则化项,那么encoder将为每个数字在欧式空间的不同区域提供一个表示,那么它就类似于autoencoder,模型是单纯的记住了训练样本,并不能从输入中抽取足够有用的信息。

在得到了VAE的损失函数后,我们便可以使用SGD来优化损失、训练模型。


The probability model perspective

下面我们从概率模型的角度来理解一下VAE,从概率模型的框架来说,VAE其实包含了一个特殊的关于 z z x x 的概率模型,我们可以用如下的联合概率来表示 z z x x
p ( x , z ) = p ( x z ) p ( z ) p(x,z) = p(x|z)p(z)

对于每一个数据点来说,生成过程可以简单的概括为:

  • 得到隐变量 z i p ( z ) z_{i} \sim p(z)
  • 得到重构的数据点 x i p ( x z ) x_{i} \sim p(x|z)

    在这里插入图片描述

在从先验分布 p ( z ) p(z) 中采样得到隐变量后,对于每一个数据点 x x 来说都拥有一个似然项 p ( x z ) p(x|z) 。因此,我们可以将上述的联合分布分解为关于先验分布和似然项表示的形式 p ( x , z ) = p ( x z ) p ( z ) p(x,z) = p(x|z)p(z)

现在我们只考虑Encoder所负责的推断(inference)阶段,目标是为给定的数据找到一个好的隐变量进行表示,根据贝叶斯公式可写作: p ( z x ) = p ( x z ) p ( z ) p ( x ) p(z|x) = \frac{p(x|z)p(z)}{p(x)}

其中 p ( x ) p(x) 这里称为evidence,我们可以通过对隐变量的边缘分布进行积分计算: p ( x ) = p ( x z ) p ( z ) d z p(x) = \int p(x|z)p(z)dz 但是直接计算会耗费指数级的时间,因为它需要对所有可能的隐变量进行计算。为了计算的方便,我们需要近似后验分布 p ( z x ) p(z|x)

VAE中使用变分推断来逼近 λ \lambda 参数化的后验分布族 q λ ( z x ) q_{\lambda}(z|x) λ \lambda 这里表示后验分布满足哪一种形式。如果 q q 这里满足高斯分布,那么 λ \lambda 这里表示的表示每一个数据点的均值和方差 λ x i = ( μ x i , σ x i 2 ) \lambda_{x_{i}} = (\mu_{x_{i}},\sigma^2_{x_{i}})

那么我们需要知道的是现在的变分后验分布 q λ ( z x ) q_{\lambda}(z|x) 有多接近于 p ( z x ) p(z|x) ,同样这里可以使用KL散度进行描述: K L ( q λ ( z x ) p ( z x ) ) = E q [ log q λ ( z x ) ] E q [ log p ( x , z ) ] + log p ( x ) KL(q_{\lambda}(z|x) | p(z|x)) = E_{q}[\log q_{\lambda}(z|x)] - E_{q}[\log p(x,z)] + \log p(x)

此时的目标便是找到合适的 λ \lambda 或变分后验分布来最小化上述的KL散度,那么最优的变分后验分布可以表示为 q λ ( z x ) = arg min λ K L ( q λ ( z x ) p ( z x ) ) q^*_{\lambda}(z|x) = \text{arg} \min_{\lambda} KL(q_{\lambda}(z|x) || p(z|x))

由于数据点 x x 所满足的分布 p ( x ) p(x) 我们是无法得知的,这也就导致了上面个的KL散度无法直接计算。因此,我们需要更多的东西来处理变分推断。假设现在存在函数 E L B O ( λ ) = E q [ log p ( x , z ) ] E q [ log q λ ( z x ) ] ELBO(\lambda) = E_{q}[ \log p(x,z)] - E_{q}[\log q_{\lambda}(z|x)]
那么将其带入关于KL散度的计算式中有 log p ( x ) = E L B O ( λ ) + K L ( q λ ( z x ) p ( z x ) ) \log p(x) = ELBO(\lambda) + KL(q_{\lambda}(z|x) || p(z|x))
同时有 K L ( q λ ( z x ) p ( z x ) ) = log p ( x ) E L B O ( λ ) KL(q_{\lambda}(z|x) || p(z|x)) =\log p(x) - ELBO(\lambda)
根据Jensen不等式可知,KL散度是恒大于等于零的,因此最小化KL散度等价于最大化ELBO。因此,我们可以使用ELBO(Evidence Lower Bound)进行近似的后验推断,此时不必计算 q λ q_{\lambda} p ( z x ) p(z|x) 之间的KL散度,而转去最大化等效的ELBO,而它在计算上是可以处理的。

在VAE中只存在局部的隐变量,即每个数据点都有一个对应的 z z ,因此我们可以将整体的计算分解对单个数据点的计算,最后求和即可。这样分解为单个数据点计算的方式也方便我们根据 λ \lambda 使用随机梯度下降进行参数的更新。

NOTE: λ \lambda 这里对所有的数据点是共享的

那么每一个数据点的ELBO表示为 E L B O i ( λ ) = E q λ ( z x i ) [ log p ( x i z ) ] K L ( q λ ( z x i ) p ( z ) ) ELBO_{i}(\lambda) = Eq_{\lambda}(z|x_{i})[ \log p(x_{i}|z)] - KL(q_{\lambda}(z|x_{i}) || p(z))
这里的 E L B O i ( λ ) ELBO_{i}(\lambda) 即下面所要讨论的损失函数 l i ( θ , ϕ ) -l_{i}(\theta,\phi) 的负数形式。我们同样可以将上式的第一项看作重构损失,将第二项看作正则项。

变分推断在最大化ELBO时针对的是变分参数 λ \lambda ,同时推断的过程也是针对于参数 ϕ \phi 。这中过程也称为变分EM(variational EM),因为整个过程是针对模型参数 ϕ \phi 来最大化数据的对数似然期望。

最后我们便可以使用变分推断算法来学习变分参数 λ \lambda ,以及使用变分EM学习模型参数 ϕ \phi


The reparametrization trick

如果我们从 q θ ( z x ) q_{\theta}(z|x) 中采样得到 z z ,那么如何计算 z z 相对于 θ \theta 的导数呢?直觉上来看,当 z z 是固定时,它的导数就应该是非零的。

对于某些分布可以使用一种巧妙的方法对样本进行重新参数化表示,使得随机性与参数无关。我们想让样本确定性的依赖于分布的参数,例如在正态分布中,参数为 μ \mu σ \sigma ,那么样本可以表示为 z = μ + σ ϵ z = \mu + \sigma \odot \epsilon 其中 ϵ N o e m a l ( 0 , 1 ) \epsilon \sim Noemal(0,1)


在这里插入图片描述

其中上图中的均值和方差来源于 θ \theta 参数化的Encoder的输出。


那么如何用代码实现VAE呢?我们仍然可以使用上篇AutoEncoder以及TensorFlow-2.0实现代码中的网络模型,不同之处在于encode函数和decode函数的实现。

    def encode(self,x):
        mean_logvar = self.encoder(x)
        N = mean_logvar.shape[0] 
        mean = tf.slice(mean_logvar, [0, 0], [N, self.latent_dim])
        logvar = tf.slice(mean_logvar, [0, self.latent_dim], [N, self.latent_dim])
        return mean,logvar
        

    def decode(self,z,apply_sigmoid = False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits
      
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean

同样在训练阶段,由于损失函数的不同,在计算损失时也和auto-encoder不一样。

# training
class train:

    @staticmethod
    def compute_loss(model,x):
        mean, logvar = model.encode(x)
        z = model.reparameterize(mean,logvar)
        x_logits = model.decode(z)
        
        # loss
        cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logits, labels=x)
        marginal_likelihood = - tf.reduce_sum(cross_ent, axis=[1, 2, 3])
        marginal_likelihood = tf.reduce_mean(marginal_likelihood)

        KL_divergence = tf.reduce_sum(mean ** 2 + tf.exp(logvar) - logvar - 1, axis=1)
        KL_divergence = tf.reduce_mean(KL_divergence)

        ELBO = marginal_likelihood - KL_divergence
        loss = -ELBO
        return loss
    
    @staticmethod
    def compute_gradient(model,x):
        with tf.GradientTape() as tape:
            loss = train.compute_loss(model,x)
        gradient = tape.gradient(loss,model.trainable_variables)
        return gradient,loss
    
    @staticmethod
    def update(optimizer,gradients,variables):
        optimizer.apply_gradients(zip(gradients,variables))

完整的实现代码:

# -*- coding: utf-8 -*-
"""
Created on Wed Sep  4 09:58:57 2019

@author: dyliang
"""

from __future__ import absolute_import,print_function,division
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
import time
import plot


# vae
class vae(keras.Model):
    
    def __init__(self,latent_dim):
        super(vae,self).__init__()
        
        self.latent_dim = latent_dim
        
        self.encoder = keras.Sequential([
                keras.layers.InputLayer(input_shape = (28,28,1)),
                keras.layers.Conv2D(filters = 32,kernel_size = 3,strides = (2,2),activation = 'relu'),
                keras.layers.Conv2D(filters = 32,kernel_size = 3,strides = (2,2),activation = 'relu'),
                keras.layers.Flatten(),
                keras.layers.Dense(256,activation = 'relu'),
                keras.layers.Dense(self.latent_dim + self.latent_dim)
                ])
        self.decoder = keras.Sequential([
                keras.layers.InputLayer(input_shape = (latent_dim,)),
                keras.layers.Dense(units = 7 * 7 * 32,activation = 'relu'),
                keras.layers.Reshape(target_shape = (7,7,32)),
                keras.layers.Conv2DTranspose(
                        filters = 64,
                        kernel_size = 3,
                        strides = (2,2),
                        padding = "SAME",
                        activation = 'relu'),
                keras.layers.Conv2DTranspose(
                        filters = 32,
                        kernel_size = 3,
                        strides = (2,2),
                        padding = "SAME",
                        activation = 'relu'),
                keras.layers.Conv2DTranspose(
                        filters = 1,
                        kernel_size = 3,
                        strides = (1,1),
                        padding = "SAME"),
                keras.layers.Conv2DTranspose(
                        filters = 1,
                        kernel_size = 3,
                        strides = (1,1),
                        padding = "SAME",
                        activation = 'sigmoid'),
                ])
                

    def encode(self,x):
        mean_logvar = self.encoder(x)
        N = mean_logvar.shape[0] 
        mean = tf.slice(mean_logvar, [0, 0], [N, self.latent_dim])
        logvar = tf.slice(mean_logvar, [0, self.latent_dim], [N, self.latent_dim])
        return mean,logvar
        

    def decode(self,z,apply_sigmoid = False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits
      
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean

# training
class train:

    @staticmethod
    def compute_loss(model,x):
        mean, logvar = model.encode(x)
        z = model.reparameterize(mean,logvar)
        x_logits = model.decode(z)
        
        # loss
        cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logits, labels=x)
        marginal_likelihood = - tf.reduce_sum(cross_ent, axis=[1, 2, 3])
        marginal_likelihood = tf.reduce_mean(marginal_likelihood)

        KL_divergence = tf.reduce_sum(mean ** 2 + tf.exp(logvar) - logvar - 1, axis=1)
        KL_divergence = tf.reduce_mean(KL_divergence)

        ELBO = marginal_likelihood - KL_divergence
        loss = -ELBO
        return loss
    
    @staticmethod
    def compute_gradient(model,x):
        with tf.GradientTape() as tape:
            loss = train.compute_loss(model,x)
        gradient = tape.gradient(loss,model.trainable_variables)
        return gradient,loss
    
    @staticmethod
    def update(optimizer,gradients,variables):
        optimizer.apply_gradients(zip(gradients,variables))
        
# hpy
latent_dim = 100
num_epochs = 100
lr = 1e-4
batch_size = 1000
train_buf = 60000
test_buf = 10000

# load data
def load_data(batch_size):
    mnist = keras.datasets.mnist
    (train_data,train_labels),(test_data,test_labels) = mnist.load_data()
    
    train_data = train_data.reshape(train_data.shape[0],28,28,1).astype('float32') / 255.
    test_data = test_data.reshape(test_data.shape[0],28,28,1).astype('float32') / 255.
    
    train_data = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size)
    train_labels = tf.data.Dataset.from_tensor_slices(train_labels).batch(batch_size)
    train_dataset = tf.data.Dataset.zip((train_data,train_labels)).shuffle(train_buf)
    
    test_data = tf.data.Dataset.from_tensor_slices(test_data).batch(batch_size)
    test_labels = tf.data.Dataset.from_tensor_slices(test_labels).batch(batch_size)
    test_dataset = tf.data.Dataset.zip((test_data,test_labels)).shuffle(test_buf)
    
    return train_dataset,test_dataset
    


# begin training
def begin():
    train_dataset,test_dataset = load_data(batch_size)
    model = vae(latent_dim)
    optimizer = keras.optimizers.Adam(lr)
    
    for epoch in range(num_epochs):
        start = time.time()
        last_loss = 0
        for train_x,_ in train_dataset:
            gradients,loss = train.compute_gradient(model,train_x)
            train.update(optimizer,gradients,model.trainable_variables)
            last_loss = loss
        # if epoch % 10 == 0:
        print ('Epoch {},loss: {},Remaining Time at This Epoch:{:.2f}'.format(
            epoch,last_loss,time.time()-start))

    plot.plot_VAE(model, test_dataset)

if __name__ == '__main__':
    begin()

输出结果:


在这里插入图片描述
在这里插入图片描述

发布了267 篇原创文章 · 获赞 91 · 访问量 19万+

猜你喜欢

转载自blog.csdn.net/Forlogen/article/details/100554185
VAE