[GAN] 4. Detailed explanation of the CGAN paper and code

Foreword

My internship in Guangzhou ended on October 15th. Between the school's internship paperwork, the defense, and catching up on my graduation thesis proposal, nearly a month went by, so the related notes were not finished before the internship ended. Starting today, updates to the related blog posts will resume.

Previously we introduced the theory behind the original GAN and DCGAN, and gave DCGAN code for generating handwritten digit images. If you are interested, please follow the links below:

  1. [GAN] 1. Using Keras to implement DCGAN to generate handwritten digit images
  2. [GAN] 2. Detailed explanation of the original GAN paper
  3. [GAN] 3. Detailed explanation of the DCGAN paper

In this blog we introduce the details of the CGAN (Conditional GAN) paper. The paper is available at: Conditional Generative Adversarial Nets. The Keras code for generating handwritten digits with CGAN is at: CGAN-mnist.


1. GAN review

To set the stage for the theoretical introduction of CGAN, we first review the relevant details of GAN. GAN consists of two networks, a generator $G$ and a discriminator $D$. The generator maps randomly sampled Gaussian noise to an image (a "fake" image), while the discriminator outputs the probability that an input image is real rather than produced by the generator, i.e., it judges whether the input image is a fake.

Assume the data is $x$, the generator's data distribution is $p_g$, and the noise prior is $p_z(z)$. The generator maps a noise sample $z$ to $G(z;\theta_g)$, and the discriminator's output on data $x$ is $D(x;\theta_d)$.

The purpose of GAN is to create something out of nothing and pass the fake off as real: the "fake" images produced by the generator $G$ should fool the discriminator $D$. In the optimal state, the discriminator's output on a generated image is 0.5, i.e., it cannot tell whether the input is a real or a fake image. The objective function of GAN is as follows:
$$\min_{G}\max_{D}V(D,G)=\mathbb{E}_{x\sim p_{data}(x)}[\log D(x)]+\mathbb{E}_{z\sim p_{z}(z)}[\log (1-D(G(z)))]\tag{1}$$
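
As a quick sanity check of what Eq. (1) measures, the value of the objective for one mini-batch can be computed directly from the discriminator's outputs. The numbers below are purely illustrative and not taken from any experiment:

import numpy as np

# Hypothetical discriminator outputs on a tiny batch
d_real = np.array([0.9, 0.8, 0.95])   # D(x) on real images
d_fake = np.array([0.1, 0.2, 0.05])   # D(G(z)) on generated images

# Monte-Carlo estimate of V(D, G): D tries to maximize it, G tries to minimize it
v = np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))
print(v)   # near 0 when D is confident; equals -2*log(2) ≈ -1.39 when D outputs 0.5 everywhere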


2. Detailed explanation of CGAN network architecture

Having reviewed the principle of GAN, we now introduce the relevant principles of CGAN. The generator of the original GAN can only generate images from random noise, so we have no way of controlling (or even knowing) the label of the generated image, and the discriminator only receives an image when judging whether it is real or generated. The main contribution of CGAN is therefore to add extra information $y$ to the inputs of both the generator and the discriminator of the original GAN. The extra information $y$ can be any auxiliary information, such as a class label. As a result, CGAN allows GAN to be trained on images together with their labels and, at test time, to generate images of a specific class from a given label, as sketched below.
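
To make the idea concrete, here is a minimal sketch (with assumed layer sizes, independent of the repository code in Section 3) of how a label $y$ can be injected into both networks by simple concatenation, as described in the paper:

from keras import Input, Model
from keras.layers import Dense, Flatten, Reshape, concatenate

# Generator: the noise z and a one-hot label y are concatenated into a joint input
noise = Input(shape=(100,))
g_label = Input(shape=(10,))
g_hidden = Dense(256, activation='relu')(concatenate([noise, g_label]))
fake_image = Reshape((28, 28, 1))(Dense(28 * 28, activation='tanh')(g_hidden))
generator = Model([noise, g_label], fake_image)

# Discriminator: the flattened image x and the label y are concatenated as well
image = Input(shape=(28, 28, 1))
d_label = Input(shape=(10,))
d_hidden = Dense(256, activation='relu')(concatenate([Flatten()(image), d_label]))
validity = Dense(1, activation='sigmoid')(d_hidden)   # probability that (x, y) is a real pair
discriminator = Model([image, d_label], validity)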

The network architecture in the CGAN paper is an MLP (fully connected network). In the CGAN generator, the input noise $z$ (drawn from $p_z(z)$) and the extra information $y$ are combined through fully connected layers into a joint hidden representation. Similarly, in the discriminator the input image $x$ and the extra information $y$ are also combined as the hidden-layer input. The network architecture diagram of CGAN is as follows:

[Figure: CGAN network architecture diagram from the paper]
Then, the objective function of CGAN can be expressed as follows:
$$\min_{G}\max_{D}V(D,G)=\mathbb{E}_{x\sim p_{data}(x)}[\log D(x|y)]+\mathbb{E}_{z\sim p_{z}(z)}[\log (1-D(G(z|y)))]\tag{2}$$

The following figure shows the handwritten digit images generated in the CGAN paper. Each row corresponds to one label; for example, the first row contains images generated with the label 0.
[Figure: MNIST digits generated by CGAN in the paper, one label per row]


3. Detailed explanation of CGAN-MNIST code

Next we walk through the Keras code for generating handwritten digit images with CGAN. The GitHub link is: CGAN-mnist. First, here is the network architecture code of CGAN:

# -*- coding: utf-8 -*-
# @Time    : 2019/10/8 13:39
# @Author  : Dai PuWei
# @File    : CGAN.py
# @Software: PyCharm

import os
import cv2
import numpy as np
import datetime
import matplotlib.pyplot as plt

from scipy.stats import truncnorm


from keras import Input
from keras import Model
from keras import Sequential

from keras.layers import Dense
from keras.layers import Activation
from keras.layers import Reshape
from keras.layers import Conv2DTranspose
from keras.layers import BatchNormalization
from keras.layers import Conv2D
from keras.layers import LeakyReLU
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.merge import multiply
from keras.layers.merge import concatenate
from keras.layers.merge import add
from keras.layers import Embedding
from keras.utils import to_categorical
from keras.optimizers import Adam
from keras.utils.generic_utils import Progbar
from copy import deepcopy
from keras.datasets import mnist

def make_trainable(net, val):
    """ Freeze or unfreeze layers
    """
    net.trainable = val
    for l in net.layers: l.trainable = val

class CGAN(object):

    def __init__(self,config,weight_path=None):
        """
        这是CGAN的初始化函数
        :param config: 参数配置类实例
        :param weight_path: 权重文件地址,默认为None
        """
        self.config = config
        self.build_cgan_model()

        if weight_path is not None:
            self.cgan.load_weights(weight_path,by_name=True)

    def build_cgan_model(self):
        """
        这是搭建CGAN模型的函数
        :return:
        """
        # 初始化输入
        self.generator_noise_input = Input(shape=(self.config.generator_noise_input_dim,))
        self.condational_label_input = Input(shape=(1,), dtype='int32')
        self.discriminator_image_input = Input(shape=self.config.discriminator_image_input_dim)

        # Define the optimizer
        self.optimizer = Adam(lr=2e-4, beta_1=0.5)

        # Build the generator and discriminator models
        self.discriminator_model = self.build_discriminator_model()
        self.discriminator_model.compile(optimizer=self.optimizer, loss=['binary_crossentropy'],metrics=['accuracy'])
        self.generator_model = self.build_generator()

        # Build the combined CGAN model (frozen discriminator on top of the generator)
        self.discriminator_model.trainable = False
        self.cgan_input = [self.generator_noise_input,self.condational_label_input]
        generator_output = self.generator_model(self.cgan_input)
        cgan_output = self.discriminator_model([generator_output,self.condational_label_input])
        self.cgan = Model(self.cgan_input,cgan_output)

        # Compile
        #self.discriminator_model.compile(optimizer=self.optimizer,loss='binary_crossentropy')
        self.cgan.compile(optimizer=self.optimizer,loss=['binary_crossentropy'])

    def build_discriminator_model(self):
        """
        这是搭建生成器模型的函数
        :return:
        """
        model = Sequential()

        model.add(Dense(512, input_dim=np.prod(self.config.discriminator_image_input_dim)))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(self.config.LeakyReLU_alpha))  # note: the dropout rate reuses the LeakyReLU alpha from the config
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(self.config.LeakyReLU_alpha))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.config.discriminator_image_input_dim)
        label = Input(shape=(1,), dtype='int32')

        label_embedding = Flatten()(Embedding(self.config.condational_label_num,
                                              np.prod(self.config.discriminator_image_input_dim))(label))
        flat_img = Flatten()(img)
        model_input = multiply([flat_img, label_embedding])
        validity = model(model_input)

        return Model([img, label], validity)


    def build_generator(self):
        """
        这是构建生成器网络的函数
        :return:返回生成器模型generotor_model
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.config.generator_noise_input_dim))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(np.prod(self.config.discriminator_image_input_dim), activation='tanh'))
        model.add(Reshape(self.config.discriminator_image_input_dim))

        model.summary()

        noise = Input(shape=(self.config.generator_noise_input_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.config.condational_label_num, self.config.generator_noise_input_dim)(label))

        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

    def train(self, train_datagen, epoch, k, batch_size=256):
        """
        这是DCGAN的训练函数
        :param train_generator:训练数据生成器
        :param epoch:周期数
        :param batch_size:小批量样本规模
        :param k:训练判别器次数
        :return:
        """
        time =datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        model_path = os.path.join(self.config.model_dir,time)
        if not os.path.exists(model_path):
            os.mkdir(model_path)

        train_result_path = os.path.join(self.config.train_result_dir,time)
        if not os.path.exists(train_result_path):
            os.mkdir(train_result_path)

        for ep in np.arange(1, epoch+1).astype(np.int32):
            cgan_losses = []
            d_losses = []
            # Create the progress bar
            length = train_datagen.batch_num
            progbar = Progbar(length)
            print('Epoch {}/{}'.format(ep, epoch))
            iter = 0
            while True:
                # break out of the while loop once the whole dataset has been traversed for this epoch
                #print("iter:{},{}".format(iter,train_datagen.get_epoch() != ep))
                if train_datagen.epoch != ep:
                    break

                # Fetch a batch of real images and build the "real" target labels
                batch_real_images, batch_real_labels = train_datagen.next_batch()
                batch_real_num_labels = np.ones((batch_size, 1))
                #batch_real_num_labels = truncnorm.rvs(0.7, 1.2, size=(batch_size, 1))
                # Sample random noise used to forge the fake images
                batch_noises = np.random.normal(0, 1, size = (batch_size, self.config.generator_noise_input_dim))
                d_loss = []
                for i in np.arange(k):
                    # Build the "fake" target labels; the fake images reuse the class labels of the real batch
                    batch_fake_num_labels = np.zeros((batch_size,1))
                    #batch_fake_num_labels = truncnorm.rvs(0.0, 0.3, size=(batch_size, 1))
                    batch_fake_labels = deepcopy(batch_real_labels)
                    batch_fake_images = self.generator_model.predict([batch_noises,batch_fake_labels])

                    # Train the discriminator on the real and fake batches
                    real_d_loss = self.discriminator_model.train_on_batch([batch_real_images,batch_real_labels],
                                                                                      batch_real_num_labels)
                    fake_d_loss = self.discriminator_model.train_on_batch([batch_fake_images, batch_fake_labels],
                                                                                      batch_fake_num_labels)
                    d_loss.append(list(0.5*np.add(real_d_loss,fake_d_loss)))
                #print(d_loss)
                d_losses.append(list(np.average(d_loss,0)))
                #print(d_losses)

                # Sample a batch of noise and random class labels to train the generator
                #batch_num_labels = truncnorm.rvs(0.7, 1.2, size=(batch_size, 1))
                batch_num_labels = np.ones((batch_size,1))
                batch_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
                cgan_loss = self.cgan.train_on_batch([batch_noises,batch_labels], batch_num_labels)
                cgan_losses.append(cgan_loss)

                # Update the progress bar
                progbar.update(iter, [('dcgan_loss', cgan_losses[iter]),
                                      ('discriminator_loss',d_losses[iter][0]),
                                      ('acc',d_losses[iter][1])])
                #print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (ep, d_losses[ep][0], 100 * d_losses[ep][1],cgan_loss))
                iter += 1
            if ep % self.config.save_epoch_interval == 0:
                model_cgan = "Epoch{}dcgan_loss{}discriminator_loss{}acc{}.h5".format(ep, np.average(cgan_losses),
                                                                                      np.average(d_losses,0)[0],np.average(d_losses,0)[1])
                self.cgan.save(os.path.join(model_path, model_cgan))
                save_dir = os.path.join(train_result_path, str("Epoch{}".format(ep)))
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                self.save_image(int(ep), save_dir)
            '''
            if int(ep) in self.config.generate_image_interval:
                save_dir = os.path.join(train_result_path,str("Epoch{}".format(ep)))
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                self.save_image(ep,save_dir)
            '''
        # Plot the loss curves of the last epoch (cgan_losses and d_losses are reset at each epoch)
        iters = np.arange(len(cgan_losses))
        plt.plot(iters, cgan_losses, 'b-', label='cgan-loss')
        plt.plot(iters, [d[0] for d in d_losses], 'r-', label='d-loss')
        plt.grid(True)
        plt.legend(loc="best")
        plt.xlabel("Iteration")
        plt.ylabel("Loss")
        plt.savefig(os.path.join(train_result_path,"loss.png"))

    def save_image(self, epoch,save_path):
        """
        这是保存生成图片的函数
        :param epoch:周期数
        :param save_path: 图片保存地址
        :return:
        """
        rows, cols = 10, 10

        fig, axs = plt.subplots(rows, cols)
        for i in range(rows):
            label = np.array([i]*cols).astype(np.int32).reshape(-1,1)   # one image per column, all with label i
            noise = np.random.normal(0, 1, (cols, 100))
            images = self.generator_model.predict([noise,label])
            images = 127.5*images+127.5
            cnt = 0
            for j in range(cols):
                #img_path = os.path.join(save_path, str(cnt) + ".png")
                #cv2.imwrite(img_path, images[cnt])
                #axs[i, j].imshow(image.astype(np.int32)[:,:,0])
                axs[i, j].imshow(images[cnt,:, :, 0].astype(np.int32), cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(os.path.join(save_path, "mnist-{}.png".format(epoch)), dpi=600)
        plt.close()

    def generate_image(self,label):
        """
        这是伪造一张图片的函数
        :param label:标签
        """
        noise = truncnorm.rvs(-1, 1, size=(1, self.config.generator_noise_input_dim))
        label = np.array([label]).T
        image = self.generator_model.predict([noise,label])[0]
        image = 127.5*(image+1)
        return image
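
Note that, unlike the concatenation described in the paper, this implementation conditions on the label through an Embedding layer whose output is multiplied element-wise with the noise vector (in the generator) or with the flattened image (in the discriminator). Isolated with assumed dimensions, the conditioning mechanism looks roughly like this:

from keras import Input
from keras.layers import Embedding, Flatten
from keras.layers.merge import multiply

noise = Input(shape=(100,))                  # noise vector z
label = Input(shape=(1,), dtype='int32')     # integer class label y
# one trainable embedding vector per class, with the same length as the noise vector
label_embedding = Flatten()(Embedding(10, 100)(label))
conditioned_input = multiply([noise, label_embedding])   # element-wise product z * embed(y)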


In order to train, we also need a dataset iterator that reads mini-batches of handwritten digit images. The code of the dataset iterator class is as follows:

# -*- coding: utf-8 -*-
# @Time    : 2019/10/8 17:29
# @Author  : Dai PuWei
# @File    : MnistGenerator.py
# @Software: PyCharm

import math
import numpy as np
from keras.datasets import mnist

class MnistGenerator(object):

    def __init__(self,batch_size):
        """
        这是图像数据生成器的初始化函数
        :param batch_size: 小批量样本规模
        """
        (x_train,y_train),(x_test,y_test) = mnist.load_data()
        #self.x = np.concatenate([x_train,x_test]).astype(np.float32)
        self.x = np.expand_dims((x_train.astype(np.float32)-127.5)/127.5,axis=-1)
        #self.y = to_categorical(np.concatenate([y_train,y_test]),num_classes=10)
        self.y = y_train.reshape(-1,1)
        #self.y = self.y[y == ]
        #print(np.shape(self.x))
        #print(np.shape(self.y))
        self.images_size = len(self.x)
        random_index = np.random.permutation(np.arange(self.images_size))
        self.x = self.x[random_index]
        self.y = self.y[random_index]

        self.epoch = 1                                  # current epoch index
        self.batch_size = int(batch_size)
        self.batch_num = math.ceil(self.images_size / self.batch_size)
        self.start = 0
        self.end = 0
        self.finish_flag = False                        # flag: the dataset has been traversed once

    def _next_batch(self):
        """
        :return:
        """
        while True:
            #batch_images = np.array([])
            #batch_labels = np.array([])
            if self.finish_flag:  # the dataset has been traversed once, so reshuffle it
                random_index = np.random.permutation(np.arange(self.images_size))
                self.x = self.x[random_index]
                self.y = self.y[random_index]
                self.finish_flag = False
                self.epoch += 1
            self.end = int(np.min([self.images_size,self.start+self.batch_size]))
            batch_images = self.x[self.start:self.end]
            batch_labels = self.y[self.start:self.end]
            batch_size = self.end - self.start
            if self.end == self.images_size:            # reached the end of the dataset exactly
                self.finish_flag = True
            if batch_size < self.batch_size:        # last incomplete batch: pad it from a reshuffled dataset
                random_index = np.random.permutation(np.arange(self.images_size))
                self.x = self.x[random_index]
                self.y = self.y[random_index]
                batch_images = np.concatenate((batch_images, self.x[0:self.batch_size - batch_size]))
                batch_labels = np.concatenate((batch_labels, self.y[0:self.batch_size - batch_size]))
                self.start = self.batch_size - batch_size
                self.epoch += 1
            else:
                self.start = self.end
            yield batch_images,batch_labels

    def next_batch(self):
        datagen = self._next_batch()
        return datagen.__next__()
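
Assuming the repository layout used in train.py below, a quick way to sanity-check the iterator is to request one batch and inspect its shape and value range:

from DataGenerator.MnistGenerator import MnistGenerator

train_datagen = MnistGenerator(batch_size=128)
batch_images, batch_labels = train_datagen.next_batch()
print(batch_images.shape, batch_labels.shape)   # expected: (128, 28, 28, 1) and (128, 1)
print(batch_images.min(), batch_images.max())   # pixels are rescaled to [-1, 1]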

The following is the training code for CGAN:

# -*- coding: utf-8 -*-
# @Time    : 2019/10/8 15:43
# @Author  : Dai PuWei
# @File    : train.py
# @Software: PyCharm

import os
import datetime

from CGAN.CGAN import CGAN
from Config.Config import MnistConfig
from DataGenerator.MnistGenerator import MnistGenerator

def run_main():
    """
    这是主函数
    """
    cfg =  MnistConfig()
    cgan = CGAN(cfg)
    batch_size = 512
    #train_datagen = Cifar10Generator(int(batch_size/2))
    train_datagen = MnistGenerator(batch_size)
    cgan.train(train_datagen,100000,1,batch_size)


if __name__ == '__main__':
    run_main()

Below are images of handwritten digits generated by CGAN during training.
Generation results after the 1st epoch:
[Figure: generated MNIST digits, epoch 1]
Generation results after the 10th epoch:
[Figure: generated MNIST digits, epoch 10]
Generation results after the 100th epoch:
[Figure: generated MNIST digits, epoch 100]
Generation results after the 1000th epoch:
[Figure: generated MNIST digits, epoch 1000]
The following is the test code of CGAN:

# -*- coding: utf-8 -*-
# @Time    : 2019/11/8 13:11
# @Author  : DaiPuWei
# @Email   : [email protected]
# @File    : test.py
# @Software: PyCharm


import os
from CGAN.CGAN import CGAN
from Config.Config import MnistConfig

def run_main():
    """
    这是主函数
    """
    weight_path = os.path.abspath("./model/20191009134644/Epoch1378dcgan_loss1.5952800512313843discriminator_loss[0.49839333 0.7379193 ]acc[0.49839333 0.7379193 ].h5")
    result_path = os.path.abspath("./test_result")
    if not os.path.exists(result_path):
        os.mkdir(result_path)
    cfg =  MnistConfig()
    cgan = CGAN(cfg,weight_path)
    cgan.save_image(0,result_path)


if __name__ == '__main__':
    run_main()

Original post: blog.csdn.net/qq_30091945/article/details/102962215