TensorFlow variational autoencoder (VAE): reconstructing and generating Fashion-MNIST images

Each Fashion-MNIST image is flattened into a 784-dimensional input vector and passed through three fully connected layers to obtain the mean and variance of the latent vector z. After the shared hidden layer fc1, two fully connected heads with 20 output nodes each follow: fc2 outputs the 20-dimensional mean vector μ, and fc3 outputs the logarithm of the variance, log σ², of the 20 latent features. A latent vector z of length 20 is then sampled using the reparameterization trick, z = μ + σ ⊙ ε with ε ~ N(0, I), and the image is reconstructed from z by the decoder layers fc4 and fc5.

Because the VAE is a generative model, it can do more than reconstruct its input samples: the decoder can also be used on its own to generate new samples. A latent vector z is sampled directly from the prior distribution p(z) = N(0, I), and decoding it produces a newly generated sample.

Code

import tensorflow as tf 
from tensorflow import keras
import numpy as np
from    matplotlib import pyplot as plt
from    PIL import Image


# Load Fashion-MNIST; the labels are returned by load_data() but the VAE never uses them.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# Scale pixel values (0-255) into [0, 1] floats so they can serve as
# Bernoulli targets for the sigmoid cross-entropy reconstruction loss.
x_train = tf.convert_to_tensor(x_train / 255., dtype=tf.float32)
x_test = tf.convert_to_tensor(x_test / 255., dtype=tf.float32)

batchsz = 100
# Training pipeline: shuffle with a buffer of 5 batches, batch, then make
# 10 passes over the data.
train_db = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(batchsz * 5)
    .batch(batchsz)
    .repeat(10)
)
# Evaluation pipeline: batched in the original order, no shuffling.
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)


class VAE(keras.Model):
    """Fully connected variational autoencoder.

    Encoder: Dense(128) + ReLU, then two Dense(20) heads giving the mean
    ``mu`` and the log-variance ``log_var`` of the approximate posterior q(z|x).
    Decoder: Dense(128) + ReLU, then Dense(784) giving one logit per pixel
    (the script feeds flattened images reshaped to ``[-1, 784]``).
    """

    def __init__(self):
        super(VAE, self).__init__()
        # Encoder network
        self.fc1 = keras.layers.Dense(128)
        self.fc2 = keras.layers.Dense(20)   # mean head
        self.fc3 = keras.layers.Dense(20)   # log-variance head
        # Decoder network
        self.fc4 = keras.layers.Dense(128)
        self.fc5 = keras.layers.Dense(784)  # per-pixel logits

    def encoder(self, x):
        """Map a batch ``x`` to ``(mu, log_var)``, each with 20 features per sample."""
        h = tf.nn.relu(self.fc1(x))
        mu = self.fc2(h)
        log_var = self.fc3(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterization trick).

        Sampling this way keeps ``z`` differentiable w.r.t. ``mu`` and ``log_var``.
        """
        # tf.shape() yields the runtime shape, so this also works when the
        # batch dimension is unknown (tf.function / Model.fit), where the
        # static `log_var.shape` would contain None.
        eps = tf.random.normal(tf.shape(log_var), dtype=log_var.dtype)
        # exp(0.5 * log_var) == sqrt(exp(log_var)), but it avoids overflowing
        # float32 to inf for large log_var before the square root is taken.
        std = tf.exp(0.5 * log_var)
        return mu + std * eps

    def decoder(self, z):
        """Map latent codes ``z`` to 784 unnormalized pixel logits per sample."""
        out = tf.nn.relu(self.fc4(z))
        out = self.fc5(out)
        return out

    def call(self, inputs, training=None):
        """Encode, sample, and decode a batch.

        Args:
            inputs: batch of flattened images, shape ``[batch, 784]``.
            training: accepted for the Keras ``call`` signature; unused here.

        Returns:
            ``(x_hat, mu, log_var)``. ``x_hat`` holds reconstruction logits
            (apply a sigmoid to get pixel intensities). ``mu`` and ``log_var``
            parameterize q(z|x).
        """
        mu, log_var = self.encoder(inputs)
        z = self.reparameterize(mu, log_var)

        x_hat = self.decoder(z)
        return x_hat, mu, log_var

model = VAE()
# Create the weights from a dummy batch shape (4 flattened 28x28 images) so
# that summary() can report layer output shapes and parameter counts.
model.build(input_shape=(4, 28 * 28))
model.summary()


optimizer = keras.optimizers.Adam(learning_rate=1e-3)
kl_weight = 1.  # weight on the KL term; 1.0 gives the standard ELBO

for step, batch in enumerate(train_db):
    x = tf.reshape(batch, [-1, 784])
    batch_size = x.shape[0]
    with tf.GradientTape() as tape:
        x_rec_logits, mu, log_var = model(x)
        # Reconstruction term: per-pixel sigmoid cross-entropy against the
        # [0, 1] pixel intensities, summed and then averaged per image.
        rec_loss = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)
        ) / batch_size
        # Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over the
        # latent dimensions and averaged per image.
        kl_div = tf.reduce_sum(
            -0.5 * (log_var + 1 - mu**2 - tf.exp(log_var))
        ) / batch_size
        loss = rec_loss + kl_weight * kl_div

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Log progress every 100 optimizer steps.
    if step % 100 == 0:
        print(step, 'kl div: ', float(kl_div), 'loss: ', float(loss))

def save_images(imgs, name, grid=10, tile=28):
    """Tile grayscale images into a ``grid`` x ``grid`` mosaic and save it.

    Images are placed column by column. The first ``grid`` images fill the
    left-most column from top to bottom, the next ``grid`` fill the second
    column, and so on. Tiles with no corresponding image stay black. Images
    beyond ``grid * grid`` are ignored.

    Args:
        imgs: sequence of 2-D arrays of ``tile`` x ``tile`` pixel values in
            [0, 255]. They are converted to uint8 before pasting.
        name: output file path; PIL chooses the format from the extension.
        grid: number of tiles along each side of the mosaic.
        tile: side length of one tile in pixels.
    """
    side = grid * tile
    mosaic = Image.new('L', (side, side))  # 'L' = 8-bit grayscale, black background
    for index, img in enumerate(imgs[:grid * grid]):
        col, row = divmod(index, grid)
        # A 2-D uint8 array is inferred as mode 'L'. This avoids the
        # `mode=` argument of fromarray(), which newer Pillow deprecates.
        tile_im = Image.fromarray(np.asarray(img, dtype=np.uint8))
        # paste() takes an (x, y) offset, so the column index sets x.
        mosaic.paste(tile_im, (col * tile, row * tile))

    # Save the assembled mosaic.
    mosaic.save(name)

# Generation: sample 100 latent codes from the N(0, I) prior and decode them.
z = tf.random.normal((100, 20))
logits = model.decoder(z)
x_hat = tf.sigmoid(logits)  # logits -> pixel intensities in (0, 1)
# Reshape to 28x28 images and rescale to 8-bit grayscale for saving.
x_hat = (tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.).astype(np.uint8)
save_images(x_hat, 'vaebuild.png')

# Reconstruction: run the first test batch through the full VAE.
x = next(iter(test_db))
flat = tf.reshape(x, [-1, 784])
logits, _, _ = model(flat)
x_hat = tf.reshape(tf.sigmoid(logits), [-1, 28, 28])

# Stack 50 originals followed by their 50 reconstructions. save_images fills
# the mosaic column by column, so the originals land in the left half and the
# matching reconstructions in the right half.
x_concat = tf.concat([x[:50], x_hat[:50]], axis=0)
x_concat = (x_concat.numpy() * 255.).astype(np.uint8)
save_images(x_concat, '10_vae.png')

 

 

Published 93 original articles · won praise 2 · Views 3016

Guess you like

Origin blog.csdn.net/qq_40041064/article/details/104916134