TensorFlow 2.0: Implementing a DCGAN for Anime Face Image Generation

1. Import the required libraries

import tensorflow as tf
import tensorflow.keras.layers as layers
import matplotlib.pyplot as plt
import os
import math
import numpy as np
import IPython.display as display
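
The code below relies on a few global hyperparameters (image_size, latent_dim, BATCH_SIZE, SHUFFLE_SIZE). Here is one possible set of values, chosen to match the shapes the code uses later (64x64 images, a 100-dimensional noise vector, and the batch size 128 passed to train() at the end); SHUFFLE_SIZE is an arbitrary choice:

# Hyperparameters (assumed values)
image_size = 64       # the generator outputs 64x64 images
latent_dim = 100      # length of the noise vector z
BATCH_SIZE = 128      # matches the value passed to train() below
SHUFFLE_SIZE = 1000   # shuffle buffer size (arbitrary)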

2. Load the dataset

PATH = './faces/'
X_train = tf.data.Dataset.list_files(PATH+'*.jpg')

def load(image_file):
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, [image_size, image_size])
    image = image / 255.0

    return image

X_train = X_train.map(
    load, num_parallel_calls=tf.data.experimental.AUTOTUNE).cache().shuffle(
    SHUFFLE_SIZE).batch(BATCH_SIZE)
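
Optionally, adding a prefetch step lets the input pipeline prepare the next batch while the current one is being used for training; a minimal sketch:

# Optional: overlap data preprocessing with training
X_train = X_train.prefetch(tf.data.experimental.AUTOTUNE)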

3. Build the generator

class Generator(tf.keras.Model):
    def __init__(self):
        super(Generator, self).__init__()
        filter = 64
        # Transposed conv layer 1: filter*8 output channels, kernel size 4, stride 1, 'valid' padding, no bias
        self.conv1 = layers.Conv2DTranspose(filter*8, 4, 1, 'valid', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        # Transposed conv layer 2
        self.conv2 = layers.Conv2DTranspose(filter*4, 4, 2, 'same', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        # Transposed conv layer 3
        self.conv3 = layers.Conv2DTranspose(filter*2, 4, 2, 'same', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        # Transposed conv layer 4
        self.conv4 = layers.Conv2DTranspose(filter*1, 4, 2, 'same', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        # Transposed conv layer 5: 3 output channels (RGB)
        self.conv5 = layers.Conv2DTranspose(3, 4, 2, 'same', use_bias=False)

    def call(self, inputs, training=None):
        x = inputs  # noise vector z of shape [b, 100]
        # Reshape into a 4D tensor for the transposed convolutions: (b, 1, 1, 100)
        x = tf.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
        x = tf.nn.relu(x)  # activation
        # Transposed conv - BN - activation: (b, 4, 4, 512)
        x = tf.nn.relu(self.bn1(self.conv1(x), training=training))
        # Transposed conv - BN - activation: (b, 8, 8, 256)
        x = tf.nn.relu(self.bn2(self.conv2(x), training=training))
        # Transposed conv - BN - activation: (b, 16, 16, 128)
        x = tf.nn.relu(self.bn3(self.conv3(x), training=training))
        # Transposed conv - BN - activation: (b, 32, 32, 64)
        x = tf.nn.relu(self.bn4(self.conv4(x), training=training))
        # Transposed conv - activation: (b, 64, 64, 3)
        x = self.conv5(x)
        x = tf.sigmoid(x)  # output in [0, 1], consistent with the preprocessing

        return x

generator = Generator()
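
A quick sanity check (optional, not required for training) confirms that the generator maps a batch of noise vectors to 64x64 RGB images; the batch size 4 here is arbitrary:

# Sanity check: noise (4, 100) -> images (4, 64, 64, 3)
sample_noise = tf.random.normal([4, latent_dim])
sample_images = generator(sample_noise, training=False)
print(sample_images.shape)  # expected: (4, 64, 64, 3)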

4. Build the discriminator

class Discriminator(tf.keras.Model):
    # Discriminator
    def __init__(self):
        super(Discriminator, self).__init__()
        filter = 64
        # Conv layer 1
        self.conv1 = layers.Conv2D(filter, 4, 2, 'valid', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        # Conv layer 2
        self.conv2 = layers.Conv2D(filter*2, 4, 2, 'valid', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        # Conv layer 3
        self.conv3 = layers.Conv2D(filter*4, 4, 2, 'valid', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        # Conv layer 4
        self.conv4 = layers.Conv2D(filter*8, 3, 1, 'valid', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        # Conv layer 5
        self.conv5 = layers.Conv2D(filter*16, 3, 1, 'valid', use_bias=False)
        self.bn5 = layers.BatchNormalization()
        # Global average pooling
        self.pool = layers.GlobalAveragePooling2D()
        # Flatten features
        self.flatten = layers.Flatten()
        # Fully connected layer producing a single real/fake logit
        self.fc = layers.Dense(1)


    def call(self, inputs, training=None):
        # Conv - BN - activation: (b, 31, 31, 64)
        x = tf.nn.leaky_relu(self.bn1(self.conv1(inputs), training=training))
        # Conv - BN - activation: (b, 14, 14, 128)
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        # Conv - BN - activation: (b, 6, 6, 256)
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
        # Conv - BN - activation: (b, 4, 4, 512)
        x = tf.nn.leaky_relu(self.bn4(self.conv4(x), training=training))
        # Conv - BN - activation: (b, 2, 2, 1024)
        x = tf.nn.leaky_relu(self.bn5(self.conv5(x), training=training))
        # Global average pooling: (b, 1024)
        x = self.pool(x)
        # Flatten
        x = self.flatten(x)
        # Output: [b, 1024] => [b, 1]
        logits = self.fc(x)

        return logits

discriminator = Discriminator()
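
Likewise, an optional quick check that the discriminator maps a batch of 64x64 images to one logit per image; dummy_images is just a random placeholder batch:

# Sanity check: images (4, 64, 64, 3) -> logits (4, 1)
dummy_images = tf.random.uniform([4, image_size, image_size, 3])
sample_logits = discriminator(dummy_images, training=False)
print(sample_logits.shape)  # expected: (4, 1)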

5. Define the loss functions

5.1 Generator loss

loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def generator_loss(disc_generated_output):
    g_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    return g_loss

5.2 Discriminator loss

def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    d_loss = real_loss + generated_loss
    return d_loss
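
As an optional check of the label convention: an ideal discriminator assigns large positive logits to real images and large negative logits to fakes, so its loss is then close to zero, while the generator loss (which wants the fakes to be scored as real) becomes large. A small sketch with hand-made logits:

# Hand-made logits, just to illustrate the label convention
real_logits = tf.constant([[5.0], [6.0]])    # real images scored high
fake_logits = tf.constant([[-5.0], [-6.0]])  # fake images scored low
print(discriminator_loss(real_logits, fake_logits).numpy())  # close to 0
print(generator_loss(fake_logits).numpy())                   # large: the generator is not fooling the discriminator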

6. Initialize the optimizers

generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

7. Define one training step

def train_step(input_image):
    noise = tf.random.normal([BATCH_SIZE, latent_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(noise, training=True)

        disc_generated_output = discriminator(gen_output, training=True)
        disc_real_output = discriminator(input_image, training=True)

        g_loss = generator_loss(disc_generated_output)
        d_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(g_loss,
                                          generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(d_loss,
                                               discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))
    return g_loss, d_loss
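
Optionally, compiling the step into a TensorFlow graph with tf.function usually speeds up training noticeably; a minimal sketch:

# Optional: compile the training step into a graph for speed
train_step = tf.function(train_step)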

8. Combine the generated images into one grid image

def combine_images(images):
    num = images.shape[0]
    width = int(math.sqrt(num))
    height = int(math.ceil(float(num) / width))
    shape = images.shape[1:3]
    image = np.zeros((height*shape[0], width*shape[1], 3),
                    dtype = np.float32)
    for index,img in enumerate(images):
        i = int(index / width)
        j = index % width
        image[i * shape[0]:(i+1) * shape[0], j * shape[1]:(j+1) * shape[1], 0:3] = img[:,:,:]
    image = image * 255.0  # the generator outputs values in [0, 1], scale to [0, 255]
    plt.imshow(image.astype(np.uint8))
    plt.axis('off')
    plt.show()
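
A quick usage example (optional): visualize a 6x6 grid from the current, possibly still untrained, generator; the grid size 36 matches what the training loop below uses:

# Show a 6x6 grid of images from the current generator
preview_noise = tf.random.normal([36, latent_dim])
combine_images(generator(preview_noise, training=False).numpy())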

9. Training

def train(BATCH_SIZE, X_train):
    # Number of images per generated grid
    generated_image_size = 36

    for epoch in range(1000):
        # Print the current epoch
        print('Epoch is ', epoch)
        index = 0
        for input_image in X_train:
            index += 1
            g_loss, d_loss = train_step(input_image)
            print('.', end='')

            if index % 50 == 0:
                # Show a grid of generated images every 50 batches
                noise_need = tf.random.normal([generated_image_size, latent_dim])
                generated_image_need = generator(noise_need, training=False)
                combine_images(generated_image_need.numpy())

            if index % 10 == 0:
                print('batch: %d, g_loss: %f, d_loss: %f' % (index, g_loss, d_loss))

train(BATCH_SIZE=128, X_train=X_train)
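
The script above does not persist the trained models; as a sketch, tf.train.Checkpoint can save both networks and their optimizers periodically (the directory './training_checkpoints' is an arbitrary choice):

# Optional: save generator/discriminator weights and optimizer state
checkpoint = tf.train.Checkpoint(generator=generator,
                                 discriminator=discriminator,
                                 generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer)
checkpoint_prefix = os.path.join('./training_checkpoints', 'ckpt')
# e.g. call this every few epochs inside train():
#     checkpoint.save(file_prefix=checkpoint_prefix)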
