【ディープラーニング】実験17 GANを利用した手書き数字サンプルの生成

GAN を使用して手書き数字サンプルを生成する

敵対的生成ネットワーク

GAN (Generative Adversarial Networks) は、深層生成ネットワーク (Generator) と深層識別ネットワーク (Discriminator) で構成される深層学習モデル アーキテクチャであり、敵対的学習を使用してトレーニングされます。GAN は、2014 年にイアン グッドフェローによって最初に提案されました。その提案以来、学術界や産業界から広範な注目と研究を受けてきました。

GAN の主なアイデアは、生成ネットワークにノイズからサンプルを生成させ、判別ネットワークを使用して生成されたサンプルと実際のデータの類似性を評価させることです。生成ネットワークはノイズ入力を使用してサンプルを生成し、弁別ネットワークは入力サンプルに基づいてサンプルが実際のデータであるかどうかを判断します。生成ネットワークと識別ネットワークは、敵対的学習を通じて互いに学習し、改善します。トレーニングプロセス中、生成ネットワークは生成されたサンプルが識別ネットワークを欺くことができることを望み、識別ネットワークは実際のデータと生成されたデータを区別して、サンプルの品質を向上させるという目的を達成することを望んでいます。

GAN は、主に画像生成、ビデオ生成、自然言語処理などの分野で幅広い用途があります。画像生成に関しては、GAN を使用して人間のアバター、動物、食べ物など、さまざまなスタイルの画像を生成できます。ビデオ生成に関しては、GAN は人間の動作や自然の風景などを含むリアルなビデオ シーケンスを生成できます。自然言語処理の観点から見ると、GAN はリアルな会話や記事などを生成できます。

GAN のトレーニング プロセスは、他の深層学習モデルよりも複雑です。生成ネットワークと識別ネットワークは、生成ネットワークによって生成されたサンプルが識別ネットワークを欺くことができるようにバランスを保つ必要があると同時に、生成されたサンプルが本物であるかどうかを識別するための識別ネットワーク自体の精度も維持する必要があります。 。GAN のトレーニング プロセスは、不安定なトレーニングやモードの崩壊などの問題を起こしやすいため、使用中に特定の調整と最適化を行う必要があります。

GAN 開発の歴史には、Conditional GAN (CGAN)、CycleGAN、Pix2Pix など、一連のバリアント モデルが登場してきました。これらのバリアント モデルはアプリケーション シナリオが異なりますが、中心となるアイデアはすべて GAN に基づいて調整および改善されています。

GAN は学界と産業界の両方で広範な注目と研究を集めており、多くの実用的なアプリケーションで GAN に対する高い需要があります。同時に、GAN 研究は、GAN の安定性、サンプルの多様性など、一連の問題や課題にも直面しています。GAN は今後も広く注目され、将来の開発に応用されることが予想されます。

プログラミング

# 导入相关库
from __future__ import print_function, division 

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt
import numpy as np
import sys
import os
class GAN():
    def __init__(self):
        # 行28,列28,也就是mnist的shape
        # 通道为1,灰度图
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        
        # 28*28*1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        
        # adam优化器
        optimizer = Adam(0.0002, 0.5)
        
        # 构造一个判别器
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        
        # 构造一个生成器
        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        
        # 在训练generator的时候不训练discriminator
        self.discriminator.trainable = False
        # 对生成的假图片进行预测
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    
    # 定义生成器
    def build_generator(self):
        
        model = Sequential()
        
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
                  
        #全连接层,28*28*1个神经元
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        #变成图片的形状
        model.add(Reshape(self.img_shape))
    
        noise = Input(shape=(self.latent_dim,))
                  
        #建立了从输入100维随机向量到28,28,1大小的图片生成模型
        img = model(noise)
                  
        return Model(noise, img)
    
    # 定义判别器
    def build_discriminator(self):
        
        model = Sequential()
        
        # 输入一张图片
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))     
        
        # 判别真伪
        model.add(Dense(1, activation='sigmoid'))
        
        img = Input(shape=self.img_shape)
        validity = model(img)
        
        return Model(img, validity)

    # 定义训练函数
    def train(self, epochs, batch_size=128, sample_interval=50):
        
        # 获取数据
        (X_train, _), (_,_) = mnist.load_data()
        
        # 进行标准化
        # 将图片像素值映射到-1到1
        X_train = X_train / 127.5 - 1
        X_train = np.expand_dims(X_train, axis=3)
       
        # 创建标签
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        
        # 先训练判别器,再训练生成器
        for epoch in range(epochs):
            # 随机选取batch_size个图片
            # 对discriminator进行训练
            # 从train训练集里面随机找出batch—size大小(这么多个)的索引值
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            # 取出一个batch大小的图片
            imgs = X_train[idx] 
            
            # 正态分布生成batch_size个100维向量作为输入
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # 用生成model的predict方法(model内部方法)将输入进行生成输出
            gen_imgs = self.generator.predict(noise)
            
            # 输入真实图片和标签全1》》到判别model,》》计算判别模型的loss
            d_loss_real = self.discriminator.train_on_batch(imgs, valid) 
            # 输入假的图片和标签全0》》到判别model,》计算判别模型的loss 
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) 
            # 将两者损失结合作为总损失
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) 
            
            # 训练generator
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # 如果输入噪音的输出是1,则正确,输入噪音输出是0,则生成网络需要改进,所以loss累加
            g_loss = self.combined.train_on_batch(noise, valid)
            # D准确度越高,代表G生成的图片越离谱,准确率为0.5左右就可以以假乱真了
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))        
            # 每sample_interval轮生成一个图片
            if epoch % sample_interval == 0 :
                self.sample_images(epoch)
    # 定义生成图片函数
    def sample_images(self, epoch):

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()

if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=10000, batch_size=256, sample_interval=200)
   Using TensorFlow backend.
   

   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
   Instructions for updating:
   Use tf.where in 2.0, which has the same broadcast rule as np.where
   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
   
   

   /home/nlp/anaconda3/lib/python3.7/site-packages/keras/engine/training.py:297: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
     'Discrepancy between trainable weights and collected trainable'
   

   0 [D loss: 0.986130, acc.: 26.17%] [G loss: 0.834596]
   

   /home/nlp/anaconda3/lib/python3.7/site-packages/keras/engine/training.py:297: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
     'Discrepancy between trainable weights and collected trainable'
   

   1 [D loss: 0.403944, acc.: 83.98%] [G loss: 0.796327]
   2 [D loss: 0.347891, acc.: 83.01%] [G loss: 0.777482]
   3 [D loss: 0.344294, acc.: 81.45%] [G loss: 0.784850]
   4 [D loss: 0.340509, acc.: 82.42%] [G loss: 0.815181]
   5 [D loss: 0.323516, acc.: 86.52%] [G loss: 0.901500]
   6 [D loss: 0.292972, acc.: 93.75%] [G loss: 0.991136]
   7 [D loss: 0.257421, acc.: 97.27%] [G loss: 1.111775]
   8 [D loss: 0.231006, acc.: 98.05%] [G loss: 1.239357]
   9 [D loss: 0.194000, acc.: 99.80%] [G loss: 1.371341]
   10 [D loss: 0.173448, acc.: 100.00%] [G loss: 1.501673]
   11 [D loss: 0.154554, acc.: 100.00%] [G loss: 1.620853]
   12 [D loss: 0.142011, acc.: 99.61%] [G loss: 1.732671]
   13 [D loss: 0.124580, acc.: 99.80%] [G loss: 1.827322]
   14 [D loss: 0.116470, acc.: 99.80%] [G loss: 1.972561]
   15 [D loss: 0.105582, acc.: 100.00%] [G loss: 2.067226]
   16 [D loss: 0.093254, acc.: 100.00%] [G loss: 2.198446]
   17 [D loss: 0.087950, acc.: 100.00%] [G loss: 2.304677]
   18 [D loss: 0.073583, acc.: 100.00%] [G loss: 2.355863]
   19 [D loss: 0.072164, acc.: 100.00%] [G loss: 2.464585]
   20 [D loss: 0.065558, acc.: 99.80%] [G loss: 2.534361]
   21 [D loss: 0.059140, acc.: 100.00%] [G loss: 2.626909]
   22 [D loss: 0.057848, acc.: 100.00%] [G loss: 2.673893]
   23 [D loss: 0.052325, acc.: 100.00%] [G loss: 2.714813]
   24 [D loss: 0.052922, acc.: 100.00%] [G loss: 2.763450]
   25 [D loss: 0.046035, acc.: 100.00%] [G loss: 2.853940]
   26 [D loss: 0.049457, acc.: 100.00%] [G loss: 2.869173]
   27 [D loss: 0.042687, acc.: 100.00%] [G loss: 2.941574]
   28 [D loss: 0.039089, acc.: 100.00%] [G loss: 2.948203]
   29 [D loss: 0.036347, acc.: 100.00%] [G loss: 2.968413]
   30 [D loss: 0.038200, acc.: 100.00%] [G loss: 3.048651]
   31 [D loss: 0.039299, acc.: 100.00%] [G loss: 3.102673]
   32 [D loss: 0.033043, acc.: 100.00%] [G loss: 3.050264]
   33 [D loss: 0.035250, acc.: 100.00%] [G loss: 3.078978]
   34 [D loss: 0.037255, acc.: 100.00%] [G loss: 3.131599]
   35 [D loss: 0.033308, acc.: 100.00%] [G loss: 3.127816]
   36 [D loss: 0.035622, acc.: 100.00%] [G loss: 3.157865]
   37 [D loss: 0.038046, acc.: 100.00%] [G loss: 3.272691]
   38 [D loss: 0.037665, acc.: 100.00%] [G loss: 3.304567]
   39 [D loss: 0.029662, acc.: 100.00%] [G loss: 3.323656]
   40 [D loss: 0.031073, acc.: 100.00%] [G loss: 3.342812]
   41 [D loss: 0.031860, acc.: 100.00%] [G loss: 3.330144]
   42 [D loss: 0.033744, acc.: 100.00%] [G loss: 3.365006]
   43 [D loss: 0.030133, acc.: 100.00%] [G loss: 3.361420]
   44 [D loss: 0.032508, acc.: 100.00%] [G loss: 3.456270]
   45 [D loss: 0.030021, acc.: 100.00%] [G loss: 3.498577]
   46 [D loss: 0.029159, acc.: 100.00%] [G loss: 3.499414]
   47 [D loss: 0.031974, acc.: 100.00%] [G loss: 3.484164]
   48 [D loss: 0.033442, acc.: 99.80%] [G loss: 3.459633]
   49 [D loss: 0.030912, acc.: 100.00%] [G loss: 3.481130]
   50 [D loss: 0.033645, acc.: 100.00%] [G loss: 3.492231]
   51 [D loss: 0.034441, acc.: 100.00%] [G loss: 3.489124]
   52 [D loss: 0.034330, acc.: 100.00%] [G loss: 3.506902]
   53 [D loss: 0.034518, acc.: 100.00%] [G loss: 3.520910]
   54 [D loss: 0.030822, acc.: 100.00%] [G loss: 3.618950]
   55 [D loss: 0.034566, acc.: 99.80%] [G loss: 3.538144]
   56 [D loss: 0.032794, acc.: 100.00%] [G loss: 3.566177]
   57 [D loss: 0.037374, acc.: 99.61%] [G loss: 3.600816]
   58 [D loss: 0.037127, acc.: 100.00%] [G loss: 3.521185]
   59 [D loss: 0.039322, acc.: 100.00%] [G loss: 3.531039]
   60 [D loss: 0.030453, acc.: 100.00%] [G loss: 3.616879]
   61 [D loss: 0.044332, acc.: 99.02%] [G loss: 3.628755]
   62 [D loss: 0.037772, acc.: 99.80%] [G loss: 3.723062]
   63 [D loss: 0.041130, acc.: 99.61%] [G loss: 3.533709]
   64 [D loss: 0.044611, acc.: 99.41%] [G loss: 3.657721]
   65 [D loss: 0.037362, acc.: 99.61%] [G loss: 3.582735]
   66 [D loss: 0.050663, acc.: 99.02%] [G loss: 3.555587]
   67 [D loss: 0.039863, acc.: 99.41%] [G loss: 3.611456]
   68 [D loss: 0.051172, acc.: 99.02%] [G loss: 3.540278]
   69 [D loss: 0.052263, acc.: 98.63%] [G loss: 3.612799]
   70 [D loss: 0.056154, acc.: 99.41%] [G loss: 3.557292]
   71 [D loss: 0.055386, acc.: 99.22%] [G loss: 3.744767]
   72 [D loss: 0.096904, acc.: 97.66%] [G loss: 3.443518]
   73 [D loss: 0.070626, acc.: 98.05%] [G loss: 3.833835]
   74 [D loss: 0.180408, acc.: 93.55%] [G loss: 3.301687]
   75 [D loss: 0.074523, acc.: 98.44%] [G loss: 3.776305]
   76 [D loss: 0.057483, acc.: 99.02%] [G loss: 3.714150]
   77 [D loss: 0.141995, acc.: 95.12%] [G loss: 3.380850]
   78 [D loss: 0.067733, acc.: 98.63%] [G loss: 3.779586]
   79 [D loss: 0.303615, acc.: 87.89%] [G loss: 2.848376]
   80 [D loss: 0.145237, acc.: 94.14%] [G loss: 3.108039]
   81 [D loss: 0.046822, acc.: 99.22%] [G loss: 3.635069]
   82 [D loss: 0.108516, acc.: 96.48%] [G loss: 3.235212]
   83 [D loss: 0.105234, acc.: 96.48%] [G loss: 3.336948]
   84 [D loss: 0.233112, acc.: 90.82%] [G loss: 2.740180]
   85 [D loss: 0.118313, acc.: 94.92%] [G loss: 3.181991]
   86 [D loss: 0.300344, acc.: 87.30%] [G loss: 2.879515]
   87 [D loss: 0.106900, acc.: 96.48%] [G loss: 3.189476]
   88 [D loss: 0.381278, acc.: 84.38%] [G loss: 2.337953]
   89 [D loss: 0.252046, acc.: 88.28%] [G loss: 2.707138]
   90 [D loss: 0.087314, acc.: 97.07%] [G loss: 3.401120]
   91 [D loss: 0.260525, acc.: 90.62%] [G loss: 2.520348]
   92 [D loss: 0.148098, acc.: 93.36%] [G loss: 2.991073]
   93 [D loss: 0.141315, acc.: 96.09%] [G loss: 2.805464]
   94 [D loss: 0.288812, acc.: 89.45%] [G loss: 2.549888]
   95 [D loss: 0.143633, acc.: 94.14%] [G loss: 2.978777]
   96 [D loss: 0.584615, acc.: 78.32%] [G loss: 2.050247]
   97 [D loss: 0.328917, acc.: 83.01%] [G loss: 2.579935]
   98 [D loss: 0.111224, acc.: 97.66%] [G loss: 3.526271]
   99 [D loss: 0.702403, acc.: 68.95%] [G loss: 1.994847]
   100 [D loss: 0.335197, acc.: 84.96%] [G loss: 2.110721]
   101 [D loss: 0.147330, acc.: 93.55%] [G loss: 2.962312]
   102 [D loss: 0.091300, acc.: 98.44%] [G loss: 3.025173]
   103 [D loss: 0.304929, acc.: 87.70%] [G loss: 2.458197]
   104 [D loss: 0.199925, acc.: 90.43%] [G loss: 2.897576]
   105 [D loss: 0.335472, acc.: 87.30%] [G loss: 2.198746]
   106 [D loss: 0.235486, acc.: 88.09%] [G loss: 2.742341]
   107 [D loss: 0.346595, acc.: 84.77%] [G loss: 2.340909]
   108 [D loss: 0.211129, acc.: 91.60%] [G loss: 2.801579]
   109 [D loss: 0.361250, acc.: 84.96%] [G loss: 2.304583]
   110 [D loss: 0.183040, acc.: 93.16%] [G loss: 2.763792]
   111 [D loss: 0.365892, acc.: 82.62%] [G loss: 2.418060]
   112 [D loss: 0.197837, acc.: 92.19%] [G loss: 2.826400]
   113 [D loss: 0.413041, acc.: 81.05%] [G loss: 2.408184]
   114 [D loss: 0.198854, acc.: 91.80%] [G loss: 2.784730]
   115 [D loss: 0.395174, acc.: 81.45%] [G loss: 2.115457]
   116 [D loss: 0.189158, acc.: 90.04%] [G loss: 2.603389]
   117 [D loss: 0.237316, acc.: 92.97%] [G loss: 2.648600]
   118 [D loss: 0.285941, acc.: 87.89%] [G loss: 2.370326]
   119 [D loss: 0.208490, acc.: 90.43%] [G loss: 2.849175]
   120 [D loss: 0.454702, acc.: 80.08%] [G loss: 1.897220]
   121 [D loss: 0.217595, acc.: 89.06%] [G loss: 2.498424]
   122 [D loss: 0.173055, acc.: 94.92%] [G loss: 2.664538]
   123 [D loss: 0.262918, acc.: 90.82%] [G loss: 2.133595]
   124 [D loss: 0.190525, acc.: 91.02%] [G loss: 2.840866]
   125 [D loss: 0.292295, acc.: 87.11%] [G loss: 2.199357]
   126 [D loss: 0.215348, acc.: 88.87%] [G loss: 2.739654]
   127 [D loss: 0.365445, acc.: 84.96%] [G loss: 2.162226]
   128 [D loss: 0.200284, acc.: 89.65%] [G loss: 2.871504]
   129 [D loss: 0.450811, acc.: 79.10%] [G loss: 1.971582]
   130 [D loss: 0.200712, acc.: 90.82%] [G loss: 2.715580]
   131 [D loss: 0.310609, acc.: 85.94%] [G loss: 2.443402]
   132 [D loss: 0.234690, acc.: 89.65%] [G loss: 2.654381]
   133 [D loss: 0.449007, acc.: 79.30%] [G loss: 1.873044]
   134 [D loss: 0.233484, acc.: 87.89%] [G loss: 2.710910]
   135 [D loss: 0.274398, acc.: 87.70%] [G loss: 2.632379]
   136 [D loss: 0.295981, acc.: 87.11%] [G loss: 2.511465]
   137 [D loss: 0.247948, acc.: 89.65%] [G loss: 2.698283]
   138 [D loss: 0.490601, acc.: 75.20%] [G loss: 2.161157]
   139 [D loss: 0.215320, acc.: 90.43%] [G loss: 2.841792]
   140 [D loss: 0.564996, acc.: 73.63%] [G loss: 1.618642]
   141 [D loss: 0.270847, acc.: 86.52%] [G loss: 2.598266]
   142 [D loss: 0.210049, acc.: 93.75%] [G loss: 3.058943]
   143 [D loss: 0.462835, acc.: 76.17%] [G loss: 2.219012]
   144 [D loss: 0.213740, acc.: 89.84%] [G loss: 2.845963]
   145 [D loss: 0.518464, acc.: 73.63%] [G loss: 1.735387]
   146 [D loss: 0.273846, acc.: 87.11%] [G loss: 2.634973]
   ……

添付ファイル: 一連の記事

シリアルナンバー 記事ディレクトリ 直接リンク
1 ボストンの住宅価格予測 https://want595.blog.csdn.net/article/details/132181950
2 アヤメのデータセット分析 https://want595.blog.csdn.net/article/details/132182057
3 特徴処理 https://want595.blog.csdn.net/article/details/132182165
4 相互検証 https://want595.blog.csdn.net/article/details/132182238
5 ニューラルネットワークの構築例 https://want595.blog.csdn.net/article/details/132182341
6 TensorFlow を使用した完全な線形回帰 https://want595.blog.csdn.net/article/details/132182417
7 TensorFlow を使用した完全なロジスティック回帰 https://want595.blog.csdn.net/article/details/132182496
8 TensorBoard のケース https://want595.blog.csdn.net/article/details/132182584
9 Keras を使用した完全な線形回帰 https://want595.blog.csdn.net/article/details/132182723
10 Keras を使用した完全なロジスティック回帰 https://want595.blog.csdn.net/article/details/132182795
11 Keras の事前トレーニング済みモデルを使用した完全な猫と犬の認識 https://want595.blog.csdn.net/article/details/132243928
12 PyTorch を使用したモデルのトレーニング https://want595.blog.csdn.net/article/details/132243989
13 ドロップアウトを使用してオーバーフィッティングを抑制する https://want595.blog.csdn.net/article/details/132244111
14 CNN を使用して MNIST 手書き認識を完了する (TensorFlow) https://want595.blog.csdn.net/article/details/132244499
15 CNN を使用して MNIST 手書き認識を完了する (Keras) https://want595.blog.csdn.net/article/details/132244552
16 CNN を使用して MNIST 手書き認識を完了する (PyTorch) https://want595.blog.csdn.net/article/details/132244641
17 GAN を使用して手書き数字サンプルを生成する https://want595.blog.csdn.net/article/details/132244764
18 自然言語処理 https://want595.blog.csdn.net/article/details/132276591

おすすめ

転載: blog.csdn.net/m0_68111267/article/details/132244764