第二个深度学习项目 使用condition GAN 训练cifar-10数据集

版权声明:版权归世界上所有无产阶级所有 https://blog.csdn.net/qq_41776781/article/details/86671556

今天学习了加载cifar-10数据集,加上之前要做condition GAN, 发现暂时还没有使用tensorflow 用于训练cifar-10数据集的condition GAN, 所以就花了一晚上用自己的模型训练cifar-10数据集,等明天结果出来再公布代码

杂记:感觉condition GAN和其他GAN的特点

第一: 生成器的输入参数之后,不仅有图像信息 而且还包括图像的标签信息

第二:生成器第一层之中需要将图像信息和标签信息进行full_connected()操作, 然后进行反置卷积 relu激活函数 batch_norm等

第三:识别器的输入参数也要包括图像的标签信息, 

第四:图像信息在第一层之中和标签信息进行full_connected()操作, 然后进行卷积 relu激活函数 batch_norm等

第五:识别器只用来判断图像是否是真, 不用判断标签是否正确, 

第六:在以后将讲的ACGAN中识别器的输出信息

睡不着!!!!!!!!!!!!!!!!不喜欢寒假除了在家写程序,写论文什么也做不了!!!!!!!!!!!boring

import os
from imageio import  imsave
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import math
import tensorflow.contrib.layers as tcl
import pickle

# Length of the one-hot condition vector (binary condition: class 0 vs. all others).
y_dim = 2
# The loading snippet from the official CIFAR-10 page mangled the dict keys with the
# default encoding when printed, so the encoding is set explicitly to latin1 below.
def unpickle(file):
    """Deserialize one CIFAR-10 batch file and return it as a dict.

    Uses latin1 decoding because the default encoding corrupts the keys
    of the Python-2-era pickle files shipped with CIFAR-10.
    """
    with open(file, 'rb') as fh:
        return pickle.load(fh, encoding='latin1')

def load_cifar():
    """Load the five CIFAR-10 training batches from ../Dataset/cifar-10/.

    Returns:
        cifar_image: float ndarray of shape (50000, 32, 32, 3), pixel
            values scaled to [0, 1].
        cifar_label: int ndarray of shape (50000, 2); binary condition
            label [1, 0] for original class 0 and [0, 1] for every other
            class (this collapses the 10 CIFAR classes into 2 for the
            conditional GAN).
    """
    images = []
    labels = []
    for i in range(1, 6):
        batch = unpickle("../Dataset/cifar-10/data_batch_" + str(i))
        # CIFAR-10 stores each image flat as (channels, height, width);
        # transpose to the (height, width, channels) layout TensorFlow expects.
        data = batch['data'].reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        images.append(data / 255.)
        # Build the binary one-hot labels without mutating the batch's own list.
        labels.append(np.array([[1, 0] if lab == 0 else [0, 1] for lab in batch['labels']]))
    # Concatenate the per-batch arrays so downstream indexing sees single
    # (50000, ...) arrays instead of lists of five (10000, ...) arrays.
    return np.concatenate(images), np.concatenate(labels)


def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU: identity for positive inputs, slope `leak` otherwise.

    For 0 < leak < 1, max(x, leak*x) is exactly the leaky ReLU.
    `name` is kept only for signature compatibility; it is unused.
    """
    scaled = x * leak
    return tf.maximum(x, scaled)


def generator(z, label, reuse=None):
    """Conditional DCGAN generator.

    Concatenates the noise vector with the one-hot condition label,
    projects to a 4x4x1024 feature map, then upsamples with three
    stride-2 transposed convolutions (4 -> 8 -> 16 -> 32) to produce a
    32x32x3 image; the final sigmoid matches the [0, 1]-scaled data.
    """
    base = 4
    with tf.variable_scope('generator', reuse=reuse):
        net = tf.concat([z, label], axis=1)
        net = tcl.fully_connected(net, base * base * 1024,
                                  activation_fn=tf.nn.relu,
                                  normalizer_fn=tcl.batch_norm)
        net = tf.reshape(net, (-1, base, base, 1024))
        # All three upsampling layers share stride, padding, batch norm and
        # initializer; only the depth and the final activation differ.
        for depth, act in ((512, tf.nn.relu), (256, tf.nn.relu), (3, tf.nn.sigmoid)):
            net = tcl.conv2d_transpose(net, depth, 3, stride=2,
                                       activation_fn=act,
                                       normalizer_fn=tcl.batch_norm,
                                       padding='SAME',
                                       weights_initializer=tf.random_normal_initializer(0, 0.02))
        return net

def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis."""
    # Broadcast the 1x1 condition tensor over x's spatial dimensions by
    # multiplying a ones map, then append it along the channel axis.
    xs = x.get_shape()
    ys = y.get_shape()
    tiled = y * tf.ones([xs[0], xs[1], xs[2], ys[3]])
    return tf.concat([x, tiled], 3)


def discriminator(image,label, reuse=False):
    """Conditional DCGAN discriminator: real/fake score for an (image, label) pair.

    The label has no classification head here; it is only concatenated
    onto the input so the critic judges image and condition jointly.
    Returns (sigmoid probability, raw logit).
    """
    base_filters = 64
    with tf.variable_scope('discriminator', reuse=reuse):
        print(image,label)
        # Reshape the one-hot label to 1x1 spatial size and tile it across
        # the image as extra input channels.
        label = tf.reshape(label, shape=[batch_size, 1, 1, y_dim])
        image = conv_cond_concat(image, label)
        print(image)

        h = tcl.conv2d(image, num_outputs=base_filters, kernel_size=4,
                       stride=2, activation_fn=lrelu)
        print("shared",h)
        # Three more stride-2 convolutions, doubling the channel count each
        # time; every layer after the first is batch-normalized.
        for mult in (2, 4, 8):
            h = tcl.conv2d(h, num_outputs=base_filters * mult, kernel_size=4,
                           stride=2, activation_fn=lrelu,
                           normalizer_fn=tcl.batch_norm)

        h = tcl.flatten(h)
        print("*******",h)
        logit = tcl.fully_connected(h, 1, activation_fn=None,
                                    weights_initializer=tf.random_normal_initializer(0, 0.02))

        return tf.nn.sigmoid(logit), logit


# 加载生成图像的噪音
def load_sample(num_samples=None, noise_dim=None):
    """Build a fixed evaluation batch of noise vectors and condition labels.

    The first half of the batch is conditioned on class [1, 0] and the
    second half on [0, 1], so the periodic sample grids show both classes.

    Args:
        num_samples: number of samples to draw; defaults to the global batch_size.
        noise_dim: length of each noise vector; defaults to the global z_dim.

    Returns:
        (z_samples, y_samples): two lists of length num_samples.
    """
    if num_samples is None:
        num_samples = batch_size
    if noise_dim is None:
        noise_dim = z_dim
    sample = np.random.uniform(0.0, 1.0, [num_samples, noise_dim]).astype(np.float32)
    z_samples = []
    y_samples = []
    # BUG FIX: the split point was hard-coded as 64 // 2, which with
    # batch_size=128 gave only a quarter of the batch label [1, 0].
    half = num_samples // 2
    for i in range(num_samples):
        z_samples.append(sample[i, :])
        y_samples.append([1, 0] if i < half else [0, 1])
    return z_samples, y_samples


# 定义加载mini_batch数据
def get_random_batch(image_all, label_all, n=None):
    """Sample a random mini-batch of images with matching labels.

    Args:
        image_all: full image array of shape (N, H, W, C).
        label_all: full label array of shape (N, y_dim).
        n: batch size; defaults to the global batch_size.

    Returns:
        (x_batch, y_batch) of shapes (n, H, W, C) and (n, y_dim).
    """
    if n is None:
        n = batch_size
    # Shuffle all row indices and keep the first n; indexing both arrays
    # with the same indices keeps images and labels aligned.
    idx = np.random.permutation(image_all.shape[0])[:n]
    x_batch = image_all[idx, :, :, :]
    y_batch = label_all[idx, :]
    return x_batch, y_batch


def imsave_image(images, size=(8, 8)):
    """Tile a batch of RGB images into a single montage image.

    Despite the name this function does not save anything; it only builds
    the grid array that the caller then writes with imsave().

    Args:
        images: list or ndarray of shape (n, h, w, 3); if n < rows*cols
            the remaining cells stay black.
        size: (rows, cols) of the grid; defaults to the original 8x8.

    Returns:
        float ndarray of shape (rows*h, cols*w, 3).
    """
    if isinstance(images, list):
        images = np.array(images)
    rows, cols = size
    h, w = images.shape[1], images.shape[2]
    grid = np.zeros((h * rows, w * cols, 3))
    # Fill the grid row-major: image idx goes to cell (idx // cols, idx % cols).
    for idx, image in enumerate(images):
        col = idx % cols
        row = idx // cols
        grid[row * h: row * h + h, col * w: col * w + w, :] = image
    return grid


def train():
    """Build the conditional GAN graph and run the training loop.

    Reads module-level globals: batch_size, width, high, z_dim, y_dim and
    sample_dir. Writes a grid of real images once at step 0 and a grid of
    generated samples every 100 steps into sample_dir.
    """
    # Load the dataset into numpy arrays (the original comment said celeba,
    # but load_cifar() is what actually runs here).
    image, image_label = load_cifar()
    print("**********************",image_label)
    z_samples, y_samples = load_sample()
    loss_d_list = []
    loss_g_list = []

    # Placeholders for real images, condition labels, generator noise and a
    # training flag.
    # NOTE(review): is_training is fed on every sess.run but nothing in the
    # visible graph consumes it — confirm whether it was meant to gate batch_norm.
    real_image = tf.placeholder(dtype=tf.float32, shape=[batch_size, width, high, 3], name='X')
    y_label = tf.placeholder(dtype=tf.float32, shape=[batch_size, y_dim], name='Y')
    noise = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='noise')
    is_training = tf.placeholder(dtype=tf.bool, name='is_training')

    fake_image = generator(noise,y_label)
    print("fake_image",fake_image)
    # The discriminator is applied to real and fake batches with shared
    # weights (reuse=True on the second call).
    d_real_logist, d_real = discriminator(real_image, y_label)
    d_fake_logist, d_fake = discriminator(fake_image, y_label, reuse=True)

    # Standard GAN sigmoid cross-entropy losses: D pushes fakes toward 0 and
    # reals toward 1; G pushes D's fake outputs toward 1.
    D_fake_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake), logits=d_fake_logist))

    D_real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real), logits=d_real_logist))
    G_fake_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_fake), logits=d_fake_logist))

    loss_d = D_real_loss + D_fake_loss
    loss_g = G_fake_loss
    # Split trainable variables by scope so each optimizer only updates its
    # own network.
    vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
    vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]

    optim_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_d, var_list=vars_d)
    optim_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_g, var_list=vars_g)

    # Start the session; allow_growth lets TensorFlow claim GPU memory on demand.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    for i in range(50001):
        image_batch, image_label_batch = get_random_batch(image, image_label)
        n = np.random.uniform(0.0, 1.0, [batch_size, z_dim]).astype(np.float32)
        # One discriminator step, then one generator step, both with Adam.
        _, d_ls = sess.run([optim_d, loss_d], feed_dict={real_image: image_batch, noise: n, y_label:image_label_batch, is_training: True})
        _, g_ls = sess.run([optim_g, loss_g], feed_dict={noise: n, y_label:image_label_batch, is_training: True})

        # Record the loss history of both networks.
        loss_d_list.append(d_ls)
        loss_g_list.append(g_ls)
        print("", i, d_ls, g_ls)

        if i == 0:
            # Save one grid of real images so the input pipeline can be
            # visually verified.
            print("**************显示真实图像**************")
            # image_batch = (image_batch + 1)/ 2
            real_imgs = [real_img[:, :, :] for real_img in image_batch[0:64]]
            image_batch = imsave_image(real_imgs)
            plt.axis('off')
            imsave(os.path.join(sample_dir, 'real_%5d.jpg' % i), image_batch)

        if i % 100 == 0:
            gen_imgs = sess.run(fake_image, feed_dict={noise: z_samples, y_label:y_samples, is_training: False})
            # gen_imgs = (gen_imgs + 1) / 2
            # Tile the first 64 generated images into an 8x8 grid and save it.
            imgs = [img[:, :, :] for img in gen_imgs[0:64]]
            gen_imgs = imsave_image(imgs)
            plt.axis('off')
            imsave(os.path.join(sample_dir, 'sample_%5d.jpg' % i), gen_imgs)



if __name__ == '__main__':
    # Optionally pin a specific GPU; without this line TensorFlow uses all GPUs.
    # os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    # Hyper-parameters read as module globals by the functions above.
    # NOTE(review): `global` at module top level is a no-op; these
    # assignments create module globals either way.
    global width, high, batch_size, z_dim
    width = 32
    high = 32
    batch_size = 128
    z_dim = 100
    # Output directory for the real/sample image grids written by train().
    sample_dir = 'Result/Celeba_ACGAN_DCGAN/cifar-10/'
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    else:
        print("*********输出文件夹已经存在是否继续执行**************")
    train()

猜你喜欢

转载自blog.csdn.net/qq_41776781/article/details/86671556