Keras-CGAN_MNIST 代码解读

最近看了CGAN的论文,2014年的论文,短小精悍,CGAN可以用于图像修补,多模态识别,感觉很有意思。抽空会把CGAN的论文理解也放上来。

论文下载地址:Conditional Generative Adversarial Nets

先放入全部代码。来源:【Keras-CGAN】MNIST / CIFAR-10

代码中噪声Z和label、输入图片和label的combine机制和论文中不同,感觉没有达到论文中的效果,不过也很好。但是论文中的机制很复杂,入门用这个就能跑出较好的效果。这份代码的网络结构是多层感知器比较简单,没有用上卷积层,如果采用DCGAN 的结构可能效果会更好。

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import numpy as np

# build_generator
model = Sequential()

model.add(Dense(256, input_dim=100))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))

model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))

model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))

model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))
model.add(Reshape((28, 28, 1)))

model.summary()

noise = Input(shape=(100,))  # input 100,这里写成100不加逗号不行哟
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, 100)(label))  # class, z dimension

model_input = multiply([noise, label_embedding])  # 把 label 和 noise embedding 在一起,作为 model 的输入
print(model_input.shape)

img = model(model_input)  # output (28,28,1)

generator = Model([noise, label], img)

# build_discriminator
model = Sequential()

model.add(Flatten(input_shape=(28,28,1)))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))

model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))

model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))

model.add(Dense(1, activation='sigmoid'))
model.summary()

img = Input(shape=(28,28,1)) # 输入 (28,28,1)
label = Input(shape=(1,), dtype='int32')

label_embedding = Flatten()(Embedding(10, np.prod((28,28,1)))(label))
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])

validity = model(model_input) # 把 label 和 G(z) embedding 在一起,作为 model 的输入
discriminator = Model([img, label], validity)


#compile model
optimizer = Adam(0.0002, 0.5)

# discriminator
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])


# The combined model  (stacked generator and discriminator)
noise = Input(shape=(100,))
label = Input(shape=(1,))
img = generator([noise,label])

# For the combined model we will only train the generator
validity = discriminator([img,label])
discriminator.trainable = False

# Trains the generator to fool the discriminator
combined = Model([noise,label], validity)
combined.summary()
combined.compile(loss='binary_crossentropy',
                 optimizer=optimizer)

def sample_images(epoch):
    r, c = 2, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    sampled_labels = np.arange(0, 10).reshape(-1, 1)
    gen_imgs = generator.predict([noise, sampled_labels])
    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].set_title("Digit: %d" % sampled_labels[cnt])
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/mnist%d.png" % epoch)
    plt.close()

batch_size = 32
sample_interval = 200

# Load the dataset
(X_train, y_train), (_, _) = mnist.load_data() # (60000,28,28)
# Rescale -1 to 1
X_train = X_train / 127.5 - 1. # tanh 的结果是 -1~1,所以这里 0-1 归一化后减1
X_train = np.expand_dims(X_train, axis=3)  # (60000,28,28,1)
# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for epoch in range(50001):
    # ---------------------
    #  Train Discriminator
    # ---------------------

    # Select a random batch of images
    idx = np.random.randint(0, X_train.shape[0], batch_size) # 0-60000 中随机抽
    #imgs = X_train[idx]
    imgs, labels = X_train[idx], y_train[idx]
    noise = np.random.normal(0, 1, (batch_size, 100))# 生成标准的高斯分布噪声

    # Generate a batch of new images
    gen_imgs = generator.predict([noise,labels])

    # Train the discriminator
    d_loss_real = discriminator.train_on_batch([imgs, labels], valid) #真实数据对应标签1
    d_loss_fake = discriminator.train_on_batch([gen_imgs,labels], fake) #生成的数据对应标签0
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---------------------
    #  Train Generator
    # ---------------------
    #noise = np.random.normal(0, 1, (batch_size, 100))
    sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
    # Train the generator (to have the discriminator label samples as valid)
    g_loss = combined.train_on_batch([noise, sampled_labels], valid)

    # Plot the progress
    if epoch % 200==0:
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

    # If at save interval => save generated image samples
    if epoch % sample_interval == 0:
        sample_images(epoch)

主要解读为GAN 的G网和D网的输入都添加条件信息的部分(add  label)

1.G: G网的输入噪声z要结合label

model  定义了一个基于多层感知器的G网结构,然后

noise = Input(shape=(100,))  # input 100,这里写成100不加逗号不行哟
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, 100)(label))  # class, z dimension

model_input = multiply([noise, label_embedding])  # 把 label 和 noise embedding 在一起,作为 model 的输入
print(model_input.shape)

img = model(model_input)  # output (28,28,1)

generator = Model([noise, label], img)

主要是embedding层的理解,可查看官方文档和相关博客。

  • label_embedding 把“词汇表”大小为10的label(一共10个类别) 转换为100的向量维度,和noise维度一样。
  • Flatten层将输入进行一维化
  • Multiply层计算输入张量列表的(逐元素间的)乘积。将label和噪声Z结合。它接受一个张量的列表, 所有的张量必须有相同的输入尺寸, 然后返回一个张量(和输入张量尺寸相同)。因此,上一步把label转为和noise一样维度。
  • img = model(model_input)  # output (28,28,1),生成一个图片
  • 由于以上的融合label的操作,使得G网的模型定义为:

generator = Model([noise, label], img)   #定义了最终G网的结构

  • Keras有两种类型的模型:序贯模型(Sequential)和函数式模型(Model)
  • Model(inputs, outputs)  generator = Model([noise, label], img)。 G网的输出还是img大小(28,28,1)

2.D:    D网的输入img(真实or生成的图片)要结合label

img = Input(shape=(28,28,1)) # 输入 (28,28,1)
label = Input(shape=(1,), dtype='int32')

label_embedding = Flatten()(Embedding(10, np.prod((28,28,1)))(label))
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])

validity = model(model_input) # 把 label 和 G(z) embedding 在一起,作为 model 的输入
discriminator = Model([img, label], validity)  #定义了最终D网的结构
  • np.prod()函数用来计算所有元素的乘积,所以label_embedding把“词汇表”大小为10的label(一共10个类别)转换成了28*28*1维,Flatten 转为一维。Multiply也是将label和输入的img结合。
  • validity = model(model_input)  利用D网定义的model给输入的label和图片的结合进行打分,判断真假。
  • discriminator = Model([img, label], validity)  #定义了最终D网的结构
  • 剩下的model complie, combined model,训练过程包括损失函数设计都和dcgan的设计一致,只是输入的部分时候要加上label
  • sampled_labels = np.arange(0, 10).reshape(-1, 1)
  • gen_imgs = generator.predict([noise, sampled_labels])
  • 结果是生成了label为0-9的图片。

最后放一张,50000次迭代后的生成图片

由于这个代码的G网,D网结构没有采用卷积层,是多层感知器的结构(MLP),所以效果不太好,改成DCGAN 的结构可能效果会好很多。

扫描二维码关注公众号,回复: 10350464 查看本文章
发布了10 篇原创文章 · 获赞 10 · 访问量 7514

猜你喜欢

转载自blog.csdn.net/qq_41647438/article/details/102939588