[Deep Learning] Experiment 17: Using a GAN to Generate Handwritten Digit Samples

Generative Adversarial Network

GAN (Generative Adversarial Network) is a deep learning architecture consisting of a generator network and a discriminator network trained against each other through adversarial learning. Originally proposed by Ian Goodfellow in 2014, it has since received extensive attention and research from both academia and industry.

The main idea of a GAN is to have the generator produce samples from noise while the discriminator evaluates how close those samples are to real data. The generator maps a noise input to a sample; the discriminator takes a sample as input and judges whether it is real data. Through adversarial learning, the two networks improve each other: during training, the generator tries to produce samples that fool the discriminator, while the discriminator tries to distinguish real data from generated data, and this competition drives up sample quality.
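This adversarial objective can be sketched numerically: both networks are scored with binary cross-entropy, the discriminator against the true real/fake labels and the generator against all-ones labels on its own fakes. A minimal sketch (the discriminator scores here are random stand-ins, not the output of a real model):

```python
import numpy as np

rng = np.random.default_rng(0)

def bce(y_true, y_pred, eps=1e-7):
    # Binary cross-entropy, the loss both networks use in the code below
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Stand-ins for discriminator outputs: D(x) = probability that x is real
d_real = rng.uniform(0.6, 0.9, 8)   # D's scores on real images (fairly confident)
d_fake = rng.uniform(0.1, 0.4, 8)   # D's scores on generated images (mostly rejected)

# Discriminator objective: push real scores toward 1 and fake scores toward 0
d_loss = 0.5 * (bce(np.ones(8), d_real) + bce(np.zeros(8), d_fake))

# Generator objective: make D output 1 on its fakes
g_loss = bce(np.ones(8), d_fake)

print("d_loss: %.3f  g_loss: %.3f" % (d_loss, g_loss))
```

With these stand-in scores the discriminator is winning, so the generator's loss comes out larger than the discriminator's; training is the process of pushing both toward the balance point where D can do no better than guessing.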

GANs have a wide range of applications, mainly in image generation, video generation, and natural language processing. In image generation, GANs can produce pictures in various styles, such as human faces, animals, and food. In video generation, GANs can produce realistic video sequences, including human actions and natural scenery. In natural language processing, GANs can generate realistic dialogue, articles, and so on.

Training a GAN is more complicated than training other deep learning models. The generator and the discriminator must stay in balance: the generator's samples should be able to fool the discriminator, while the discriminator must remain accurate enough to tell real samples from generated ones. Because GAN training is prone to instability and mode collapse, it usually requires careful tuning and optimization in practice.
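One common family of stabilization tweaks adjusts the target labels and learning rates so the discriminator does not win the game too quickly. These tweaks are not used in the code below; the snippet is only a hedged sketch of what such adjustments look like:

```python
import numpy as np

batch_size = 128

# One-sided label smoothing: train the discriminator against 0.9 instead of
# 1.0 for real images, which keeps it from becoming overconfident.
valid = np.full((batch_size, 1), 0.9)
fake = np.zeros((batch_size, 1))     # fake labels are usually left at 0

# Another common tweak: give the discriminator a smaller learning rate than
# the generator (hypothetical values for illustration).
d_lr, g_lr = 1e-4, 2e-4
print(valid[0, 0], fake[0, 0], d_lr, g_lr)
```

Which remedies help depends on the model and data; in the plain GAN below, the balance is maintained simply by alternating one discriminator update with one generator update per epoch.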

A series of variant models has emerged over the course of GAN development, such as Conditional GAN (CGAN), CycleGAN, and Pix2Pix. These variants target different application scenarios, but at their core they are all adjustments and improvements built on the basic GAN.

GANs have received extensive attention and research in both academia and industry, and many practical applications have a high demand for them. At the same time, GAN research still faces a series of problems and challenges, such as training stability and sample diversity. GANs can be expected to continue to receive widespread attention and application in the future.

Code

# Import required libraries
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt
import numpy as np
import os

class GAN():
    def __init__(self):
        # 28 rows, 28 columns: the shape of MNIST images
        # 1 channel: grayscale
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1

        # 28*28*1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        # Adam optimizer
        optimizer = Adam(0.0002, 0.5)

        # Build the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)

        # Freeze the discriminator while training the generator
        self.discriminator.trainable = False
        # Score the generated (fake) images
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    
    # Define the generator
    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        # Fully connected layer with 28*28*1 units
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        # Reshape the output into image form
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))

        # Map a 100-dimensional noise vector to a 28x28x1 image
        img = model(noise)

        return Model(noise, img)
    
    # Define the discriminator
    def build_discriminator(self):

        model = Sequential()

        # Input an image
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))

        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))

        # Output: probability that the input is real
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    # Define the training function
    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the data
        (X_train, _), (_, _) = mnist.load_data()

        # Normalize: map pixel values to [-1, 1]
        X_train = X_train / 127.5 - 1
        X_train = np.expand_dims(X_train, axis=3)

        # Create the labels
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        # Train the discriminator first, then the generator
        for epoch in range(epochs):
            # Train the discriminator:
            # draw batch_size random indices from the training set
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            # take one batch of real images
            imgs = X_train[idx]

            # Sample batch_size 100-dimensional vectors from a normal distribution
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # Use the generator's predict method to produce fake images
            gen_imgs = self.generator.predict(noise)

            # Feed real images with all-1 labels to the discriminator and compute its loss
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            # Feed fake images with all-0 labels to the discriminator and compute its loss
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            # Average the two as the total discriminator loss
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train the generator
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # The generator succeeds when the discriminator outputs 1 on its fakes,
            # so it is trained against all-1 labels
            g_loss = self.combined.train_on_batch(noise, valid)
            # High D accuracy means G's images are easy to spot;
            # around 50% accuracy the fakes are passing as real
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
            # Save a grid of samples every sample_interval epochs
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
    # Define the image-sampling function
    def sample_images(self, epoch):

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale from [-1, 1] back to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()

if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=10000, batch_size=256, sample_interval=200)
   Using TensorFlow backend.
   

   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
   Instructions for updating:
   Use tf.where in 2.0, which has the same broadcast rule as np.where
   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
   
   

   /home/nlp/anaconda3/lib/python3.7/site-packages/keras/engine/training.py:297: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
     'Discrepancy between trainable weights and collected trainable'
   

   0 [D loss: 0.986130, acc.: 26.17%] [G loss: 0.834596]
   

   1 [D loss: 0.403944, acc.: 83.98%] [G loss: 0.796327]
   2 [D loss: 0.347891, acc.: 83.01%] [G loss: 0.777482]
   3 [D loss: 0.344294, acc.: 81.45%] [G loss: 0.784850]
   4 [D loss: 0.340509, acc.: 82.42%] [G loss: 0.815181]
   5 [D loss: 0.323516, acc.: 86.52%] [G loss: 0.901500]
   6 [D loss: 0.292972, acc.: 93.75%] [G loss: 0.991136]
   7 [D loss: 0.257421, acc.: 97.27%] [G loss: 1.111775]
   8 [D loss: 0.231006, acc.: 98.05%] [G loss: 1.239357]
   9 [D loss: 0.194000, acc.: 99.80%] [G loss: 1.371341]
   10 [D loss: 0.173448, acc.: 100.00%] [G loss: 1.501673]
   11 [D loss: 0.154554, acc.: 100.00%] [G loss: 1.620853]
   12 [D loss: 0.142011, acc.: 99.61%] [G loss: 1.732671]
   13 [D loss: 0.124580, acc.: 99.80%] [G loss: 1.827322]
   14 [D loss: 0.116470, acc.: 99.80%] [G loss: 1.972561]
   15 [D loss: 0.105582, acc.: 100.00%] [G loss: 2.067226]
   16 [D loss: 0.093254, acc.: 100.00%] [G loss: 2.198446]
   17 [D loss: 0.087950, acc.: 100.00%] [G loss: 2.304677]
   18 [D loss: 0.073583, acc.: 100.00%] [G loss: 2.355863]
   19 [D loss: 0.072164, acc.: 100.00%] [G loss: 2.464585]
   20 [D loss: 0.065558, acc.: 99.80%] [G loss: 2.534361]
   21 [D loss: 0.059140, acc.: 100.00%] [G loss: 2.626909]
   22 [D loss: 0.057848, acc.: 100.00%] [G loss: 2.673893]
   23 [D loss: 0.052325, acc.: 100.00%] [G loss: 2.714813]
   24 [D loss: 0.052922, acc.: 100.00%] [G loss: 2.763450]
   25 [D loss: 0.046035, acc.: 100.00%] [G loss: 2.853940]
   26 [D loss: 0.049457, acc.: 100.00%] [G loss: 2.869173]
   27 [D loss: 0.042687, acc.: 100.00%] [G loss: 2.941574]
   28 [D loss: 0.039089, acc.: 100.00%] [G loss: 2.948203]
   29 [D loss: 0.036347, acc.: 100.00%] [G loss: 2.968413]
   30 [D loss: 0.038200, acc.: 100.00%] [G loss: 3.048651]
   31 [D loss: 0.039299, acc.: 100.00%] [G loss: 3.102673]
   32 [D loss: 0.033043, acc.: 100.00%] [G loss: 3.050264]
   33 [D loss: 0.035250, acc.: 100.00%] [G loss: 3.078978]
   34 [D loss: 0.037255, acc.: 100.00%] [G loss: 3.131599]
   35 [D loss: 0.033308, acc.: 100.00%] [G loss: 3.127816]
   36 [D loss: 0.035622, acc.: 100.00%] [G loss: 3.157865]
   37 [D loss: 0.038046, acc.: 100.00%] [G loss: 3.272691]
   38 [D loss: 0.037665, acc.: 100.00%] [G loss: 3.304567]
   39 [D loss: 0.029662, acc.: 100.00%] [G loss: 3.323656]
   40 [D loss: 0.031073, acc.: 100.00%] [G loss: 3.342812]
   41 [D loss: 0.031860, acc.: 100.00%] [G loss: 3.330144]
   42 [D loss: 0.033744, acc.: 100.00%] [G loss: 3.365006]
   43 [D loss: 0.030133, acc.: 100.00%] [G loss: 3.361420]
   44 [D loss: 0.032508, acc.: 100.00%] [G loss: 3.456270]
   45 [D loss: 0.030021, acc.: 100.00%] [G loss: 3.498577]
   46 [D loss: 0.029159, acc.: 100.00%] [G loss: 3.499414]
   47 [D loss: 0.031974, acc.: 100.00%] [G loss: 3.484164]
   48 [D loss: 0.033442, acc.: 99.80%] [G loss: 3.459633]
   49 [D loss: 0.030912, acc.: 100.00%] [G loss: 3.481130]
   50 [D loss: 0.033645, acc.: 100.00%] [G loss: 3.492231]
   51 [D loss: 0.034441, acc.: 100.00%] [G loss: 3.489124]
   52 [D loss: 0.034330, acc.: 100.00%] [G loss: 3.506902]
   53 [D loss: 0.034518, acc.: 100.00%] [G loss: 3.520910]
   54 [D loss: 0.030822, acc.: 100.00%] [G loss: 3.618950]
   55 [D loss: 0.034566, acc.: 99.80%] [G loss: 3.538144]
   56 [D loss: 0.032794, acc.: 100.00%] [G loss: 3.566177]
   57 [D loss: 0.037374, acc.: 99.61%] [G loss: 3.600816]
   58 [D loss: 0.037127, acc.: 100.00%] [G loss: 3.521185]
   59 [D loss: 0.039322, acc.: 100.00%] [G loss: 3.531039]
   60 [D loss: 0.030453, acc.: 100.00%] [G loss: 3.616879]
   61 [D loss: 0.044332, acc.: 99.02%] [G loss: 3.628755]
   62 [D loss: 0.037772, acc.: 99.80%] [G loss: 3.723062]
   63 [D loss: 0.041130, acc.: 99.61%] [G loss: 3.533709]
   64 [D loss: 0.044611, acc.: 99.41%] [G loss: 3.657721]
   65 [D loss: 0.037362, acc.: 99.61%] [G loss: 3.582735]
   66 [D loss: 0.050663, acc.: 99.02%] [G loss: 3.555587]
   67 [D loss: 0.039863, acc.: 99.41%] [G loss: 3.611456]
   68 [D loss: 0.051172, acc.: 99.02%] [G loss: 3.540278]
   69 [D loss: 0.052263, acc.: 98.63%] [G loss: 3.612799]
   70 [D loss: 0.056154, acc.: 99.41%] [G loss: 3.557292]
   71 [D loss: 0.055386, acc.: 99.22%] [G loss: 3.744767]
   72 [D loss: 0.096904, acc.: 97.66%] [G loss: 3.443518]
   73 [D loss: 0.070626, acc.: 98.05%] [G loss: 3.833835]
   74 [D loss: 0.180408, acc.: 93.55%] [G loss: 3.301687]
   75 [D loss: 0.074523, acc.: 98.44%] [G loss: 3.776305]
   76 [D loss: 0.057483, acc.: 99.02%] [G loss: 3.714150]
   77 [D loss: 0.141995, acc.: 95.12%] [G loss: 3.380850]
   78 [D loss: 0.067733, acc.: 98.63%] [G loss: 3.779586]
   79 [D loss: 0.303615, acc.: 87.89%] [G loss: 2.848376]
   80 [D loss: 0.145237, acc.: 94.14%] [G loss: 3.108039]
   81 [D loss: 0.046822, acc.: 99.22%] [G loss: 3.635069]
   82 [D loss: 0.108516, acc.: 96.48%] [G loss: 3.235212]
   83 [D loss: 0.105234, acc.: 96.48%] [G loss: 3.336948]
   84 [D loss: 0.233112, acc.: 90.82%] [G loss: 2.740180]
   85 [D loss: 0.118313, acc.: 94.92%] [G loss: 3.181991]
   86 [D loss: 0.300344, acc.: 87.30%] [G loss: 2.879515]
   87 [D loss: 0.106900, acc.: 96.48%] [G loss: 3.189476]
   88 [D loss: 0.381278, acc.: 84.38%] [G loss: 2.337953]
   89 [D loss: 0.252046, acc.: 88.28%] [G loss: 2.707138]
   90 [D loss: 0.087314, acc.: 97.07%] [G loss: 3.401120]
   91 [D loss: 0.260525, acc.: 90.62%] [G loss: 2.520348]
   92 [D loss: 0.148098, acc.: 93.36%] [G loss: 2.991073]
   93 [D loss: 0.141315, acc.: 96.09%] [G loss: 2.805464]
   94 [D loss: 0.288812, acc.: 89.45%] [G loss: 2.549888]
   95 [D loss: 0.143633, acc.: 94.14%] [G loss: 2.978777]
   96 [D loss: 0.584615, acc.: 78.32%] [G loss: 2.050247]
   97 [D loss: 0.328917, acc.: 83.01%] [G loss: 2.579935]
   98 [D loss: 0.111224, acc.: 97.66%] [G loss: 3.526271]
   99 [D loss: 0.702403, acc.: 68.95%] [G loss: 1.994847]
   100 [D loss: 0.335197, acc.: 84.96%] [G loss: 2.110721]
   101 [D loss: 0.147330, acc.: 93.55%] [G loss: 2.962312]
   102 [D loss: 0.091300, acc.: 98.44%] [G loss: 3.025173]
   103 [D loss: 0.304929, acc.: 87.70%] [G loss: 2.458197]
   104 [D loss: 0.199925, acc.: 90.43%] [G loss: 2.897576]
   105 [D loss: 0.335472, acc.: 87.30%] [G loss: 2.198746]
   106 [D loss: 0.235486, acc.: 88.09%] [G loss: 2.742341]
   107 [D loss: 0.346595, acc.: 84.77%] [G loss: 2.340909]
   108 [D loss: 0.211129, acc.: 91.60%] [G loss: 2.801579]
   109 [D loss: 0.361250, acc.: 84.96%] [G loss: 2.304583]
   110 [D loss: 0.183040, acc.: 93.16%] [G loss: 2.763792]
   111 [D loss: 0.365892, acc.: 82.62%] [G loss: 2.418060]
   112 [D loss: 0.197837, acc.: 92.19%] [G loss: 2.826400]
   113 [D loss: 0.413041, acc.: 81.05%] [G loss: 2.408184]
   114 [D loss: 0.198854, acc.: 91.80%] [G loss: 2.784730]
   115 [D loss: 0.395174, acc.: 81.45%] [G loss: 2.115457]
   116 [D loss: 0.189158, acc.: 90.04%] [G loss: 2.603389]
   117 [D loss: 0.237316, acc.: 92.97%] [G loss: 2.648600]
   118 [D loss: 0.285941, acc.: 87.89%] [G loss: 2.370326]
   119 [D loss: 0.208490, acc.: 90.43%] [G loss: 2.849175]
   120 [D loss: 0.454702, acc.: 80.08%] [G loss: 1.897220]
   121 [D loss: 0.217595, acc.: 89.06%] [G loss: 2.498424]
   122 [D loss: 0.173055, acc.: 94.92%] [G loss: 2.664538]
   123 [D loss: 0.262918, acc.: 90.82%] [G loss: 2.133595]
   124 [D loss: 0.190525, acc.: 91.02%] [G loss: 2.840866]
   125 [D loss: 0.292295, acc.: 87.11%] [G loss: 2.199357]
   126 [D loss: 0.215348, acc.: 88.87%] [G loss: 2.739654]
   127 [D loss: 0.365445, acc.: 84.96%] [G loss: 2.162226]
   128 [D loss: 0.200284, acc.: 89.65%] [G loss: 2.871504]
   129 [D loss: 0.450811, acc.: 79.10%] [G loss: 1.971582]
   130 [D loss: 0.200712, acc.: 90.82%] [G loss: 2.715580]
   131 [D loss: 0.310609, acc.: 85.94%] [G loss: 2.443402]
   132 [D loss: 0.234690, acc.: 89.65%] [G loss: 2.654381]
   133 [D loss: 0.449007, acc.: 79.30%] [G loss: 1.873044]
   134 [D loss: 0.233484, acc.: 87.89%] [G loss: 2.710910]
   135 [D loss: 0.274398, acc.: 87.70%] [G loss: 2.632379]
   136 [D loss: 0.295981, acc.: 87.11%] [G loss: 2.511465]
   137 [D loss: 0.247948, acc.: 89.65%] [G loss: 2.698283]
   138 [D loss: 0.490601, acc.: 75.20%] [G loss: 2.161157]
   139 [D loss: 0.215320, acc.: 90.43%] [G loss: 2.841792]
   140 [D loss: 0.564996, acc.: 73.63%] [G loss: 1.618642]
   141 [D loss: 0.270847, acc.: 86.52%] [G loss: 2.598266]
   142 [D loss: 0.210049, acc.: 93.75%] [G loss: 3.058943]
   143 [D loss: 0.462835, acc.: 76.17%] [G loss: 2.219012]
   144 [D loss: 0.213740, acc.: 89.84%] [G loss: 2.845963]
   145 [D loss: 0.518464, acc.: 73.63%] [G loss: 1.735387]
   146 [D loss: 0.273846, acc.: 87.11%] [G loss: 2.634973]
   ……
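The generator's tanh output lives in [-1, 1], which is why train() maps pixels with x / 127.5 - 1 and sample_images() undoes it with 0.5 * x + 0.5 before plotting. The round trip can be checked in isolation:

```python
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])   # darkest, middle, brightest MNIST values
scaled = pixels / 127.5 - 1              # the mapping used in train(): [0, 255] -> [-1, 1]
restored = 0.5 * scaled + 0.5            # the inverse used in sample_images(): [-1, 1] -> [0, 1]
print(scaled)    # [-1.  0.  1.]
print(restored)  # [0.  0.5 1. ]
```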

Attachment: series of articles

No. Article Link
1 Boston house price forecast https://want595.blog.csdn.net/article/details/132181950
2 Iris dataset analysis https://want595.blog.csdn.net/article/details/132182057
3 Feature processing https://want595.blog.csdn.net/article/details/132182165
4 Cross-validation https://want595.blog.csdn.net/article/details/132182238
5 Constructing a Neural Network Example https://want595.blog.csdn.net/article/details/132182341
6 Complete linear regression using TensorFlow https://want595.blog.csdn.net/article/details/132182417
7 Complete logistic regression using TensorFlow https://want595.blog.csdn.net/article/details/132182496
8 TensorBoard case https://want595.blog.csdn.net/article/details/132182584
9 Complete linear regression using Keras https://want595.blog.csdn.net/article/details/132182723
10 Complete logistic regression using Keras https://want595.blog.csdn.net/article/details/132182795
11 Complete cat and dog recognition using Keras pre-trained model https://want595.blog.csdn.net/article/details/132243928
12 Training models using PyTorch https://want595.blog.csdn.net/article/details/132243989
13 Use Dropout to suppress overfitting https://want595.blog.csdn.net/article/details/132244111
14 Using CNN to complete MNIST handwriting recognition (TensorFlow) https://want595.blog.csdn.net/article/details/132244499
15 Using CNN to complete MNIST handwriting recognition (Keras) https://want595.blog.csdn.net/article/details/132244552
16 Using CNN to complete MNIST handwriting recognition (PyTorch) https://want595.blog.csdn.net/article/details/132244641
17 Using GAN to generate handwritten digit samples https://want595.blog.csdn.net/article/details/132244764
18 natural language processing https://want595.blog.csdn.net/article/details/132276591

Origin blog.csdn.net/m0_68111267/article/details/132244764