Implementing DCGAN in TensorFlow

1. A Brief Summary of DCGAN

[Paper]:

            http://arxiv.org/abs/1511.06434

[GitHub]:

            https://github.com/Newmu/dcgan_code  (Theano)

            https://github.com/carpedm20/DCGAN-tensorflow  (TensorFlow)

            https://github.com/jacobgil/keras-dcgan  (Keras)

            https://github.com/soumith/dcgan.torch  (Torch)


DCGAN is one of the better improvements on the original GAN, and its contribution lies mainly in the network architecture. The DCGAN architecture is still widely used today; it greatly improved both the stability of GAN training and the quality of the generated results.


The figure above (from the paper; omitted here) shows the DCGAN generator used for the LSUN scene model. A 100-dimensional uniform noise vector z is projected to a small-spatial-extent convolutional representation with many feature maps, and a series of four fractionally-strided convolutions (mistakenly called deconvolutions in some recent papers) then converts this high-level representation into a 64*64 pixel image. Compared with the original GAN, DCGAN replaces the fully connected layers almost entirely with convolutional layers, and the discriminator is nearly symmetric to the generator. Note that the network contains no pooling or upsampling layers: fractionally-strided convolutions take the place of upsampling, which increases training stability.
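To make the shape progression concrete, here is a minimal TF 1.x sketch of that generator (my own illustration, not the repo's code, which appears later; the batch size, kernel size, and filter counts follow the paper):

import tensorflow as tf

# Sketch: z (100-d) -> project and reshape to 4x4x512 -> four stride-2
# fractionally-strided convolutions -> a 64x64x3 image in [-1, 1].
batch_size, z_dim = 64, 100
z = tf.random_uniform([batch_size, z_dim], -1., 1.)

h = tf.layers.dense(z, 4 * 4 * 512)                     # project z
h = tf.nn.relu(tf.reshape(h, [batch_size, 4, 4, 512]))  # small spatial extent, many feature maps
for filters in (256, 128, 64):                          # spatial size: 4 -> 8 -> 16 -> 32
    h = tf.nn.relu(tf.layers.conv2d_transpose(h, filters, 5, strides=2, padding='same'))
img = tf.tanh(tf.layers.conv2d_transpose(h, 3, 5, strides=2, padding='same'))  # 32 -> 64
print(img.shape)  # (64, 64, 64, 3)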

The main reasons DCGAN stabilizes GAN training:

◆  Strided convolutions replace pooling and upsampling layers: convolutions are effective at extracting image features, and convolutional layers also take the place of fully connected layers.

◆  Batch normalization is applied to nearly every layer of both the generator G and the discriminator D, normalizing each layer's activations, which speeds up training and improves its stability. (The generator's output layer and the discriminator's input layer are left without batchnorm.)

◆  The discriminator uses the LeakyReLU activation instead of ReLU, which prevents sparse gradients; the generator still uses ReLU, except for its output layer, which uses tanh.

◆  Training uses the Adam optimizer, preferably with a learning rate of 0.0002 (I have tried other learning rates, and 0.0002 performed best).
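As a quick sketch of that last bullet in TF 1.x (the variable and loss here are dummies, for illustration only):

import tensorflow as tf

# The recommended optimizer settings: Adam with lr=0.0002 and beta1=0.5.
w = tf.Variable(1.0)
loss = tf.square(w - 3.0)  # placeholder loss, not part of DCGAN
train_op = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(100):
        sess.run(train_op)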


Summary of the main changes:

1. Replace pooling layers with strided convolutions. (This lets the discriminator learn its own spatial downsampling and the generator its own spatial upsampling.)

2. Use batchnorm in both the generator and the discriminator, which:

- mitigates poor initialization;

- helps gradients propagate to every layer;

- prevents the generator from collapsing all samples to a single point.

3. Remove the fully connected layers from the CNN.

4. Use ReLU in every generator layer except the output layer, which uses tanh.

5. Use LeakyReLU in all discriminator layers.
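Taken together, guidelines 1, 3, and 5 yield a discriminator like this minimal TF 1.x sketch (my illustration; the filter counts follow the paper, and the random input stands in for a real batch):

import tensorflow as tf

# Strided convolutions instead of pooling, LeakyReLU throughout,
# and no fully connected hidden layers; only the final logit is linear.
x = tf.random_uniform([64, 64, 64, 3], -1., 1.)  # stand-in batch of 64x64 RGB images
h = x
for filters in (64, 128, 256, 512):              # spatial size: 64 -> 32 -> 16 -> 8 -> 4
    h = tf.nn.leaky_relu(
        tf.layers.conv2d(h, filters, 5, strides=2, padding='same'), alpha=0.2)
logits = tf.layers.dense(tf.reshape(h, [64, -1]), 1)  # one real/fake logit per image
print(logits.shape)  # (64, 1)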

Problems:

Although DCGAN is a solid architecture, for GAN training stability it treats the symptoms rather than the root cause. Training still requires carefully balancing the progress of G and D: in practice, one network is often updated several times for each update of the other.

For details, see my other post comparing the principles of the various GAN variants.

2. Implementing DCGAN

Note:

This article mainly follows DCGAN-tensorflow.

There are four files in total: main.py, model.py, ops.py, and utils.py.

1) Dataset

You can build the dataset yourself or download one directly.

To download, run:

python download.py mnist celebA

This downloads the mnist and celebA datasets.

The following walks through building your own dataset.

Collecting the raw images

First we need a large number of anime images. These can be scraped from an anime site such as konachan.net; the crawler code is as follows:


import requests  # HTTP library
from bs4 import BeautifulSoup  # HTML parsing library
import os  # filesystem operations
import traceback  # printing exception tracebacks

def download(url, filename):
    if os.path.exists(filename):
        print('file exists!')
        return
    try:
        r = requests.get(url, stream=True, timeout=60)
        r.raise_for_status()
        with open(filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive chunks
                    f.write(chunk)
                    f.flush()
        return filename
    except KeyboardInterrupt:
        if os.path.exists(filename):
            os.remove(filename)  # drop the partial file before exiting
        raise
    except Exception:
        traceback.print_exc()
        if os.path.exists(filename):
            os.remove(filename)

if not os.path.exists('imgs'):
    os.makedirs('imgs')

start = 1
end = 8000
for i in range(start, end + 1):
    url = 'http://konachan.net/post?page=%d&tags=' % i
    html = requests.get(url).text  # fetch the listing page
    soup = BeautifulSoup(html, 'html.parser')  # parse it with the built-in parser
    for img in soup.find_all('img', class_="preview"):  # every preview thumbnail tag
        target_url = 'http:' + img['src']
        filename = os.path.join('imgs', target_url.split('/')[-1])
        download(target_url, filename)
    print('%d / %d' % (i, end))

Now we have raw images, but our goal is to generate anime faces: we don't need the whole image, and the extra content would interfere with training. So we run face detection and crop out the faces.

Detecting and cropping the faces

We use an OpenCV-based cascade classifier for anime faces, following lbpcascade_animeface.

First, download the classifier:

wget https://raw.githubusercontent.com/nagadomi/lbpcascade_animeface/master/lbpcascade_animeface.xml

Then run the following code to crop the faces out of the images.

import cv2
import os.path
from glob import glob

def detect(filename, cascade_file="lbpcascade_animeface.xml"):
    if not os.path.isfile(cascade_file):
        raise RuntimeError("%s: not found" % cascade_file)

    cascade = cv2.CascadeClassifier(cascade_file)
    image = cv2.imread(filename)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.equalizeHist(gray)  # equalize the histogram before detection

    faces = cascade.detectMultiScale(
        gray,
        # detector options
        scaleFactor=1.1,
        minNeighbors=5,
        minSize=(48, 48)
    )

    for i, (x, y, w, h) in enumerate(faces):
        face = image[y: y + h, x: x + w, :]
        face = cv2.resize(face, (96, 96))
        # include the face index so multiple faces in one image don't overwrite each other
        save_filename = '%s-%d.jpg' % (os.path.basename(filename).split('.')[0], i)
        cv2.imwrite("faces/" + save_filename, face)

if __name__ == '__main__':
    if not os.path.exists('faces'):
        os.makedirs('faces')
    file_list = glob('imgs/*.jpg')
    for filename in file_list:
        detect(filename)

The processed images look like this. (Sample images omitted.)

I will put the anime-face dataset on a network drive later, so anyone who needs it can download it.

The source code follows.

main.py

# !/usr/bin/python3
# -*- coding:utf-8 -*-
__author__ = 'gavin'

import os  # OS / filesystem utilities
import scipy.misc  # scipy's misc module: read/write images as arrays
import numpy as np  # matrix computation
from model import DCGAN  # the DCGAN class from model.py
from utils import pp, visualize, to_json, show_all_variables  # helpers from utils.py
import tensorflow as tf
flags = tf.app.flags  # command-line flags (the equivalent of argv): name, default value, description
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")  # number of training epochs
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")  # Adam learning rate
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")  # Adam momentum (moving-average) term
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")  # cap on the number of training images
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")  # batch size
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")  # input height; images are center-cropped to this size
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")  # input width; defaults to the height
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")  # output height
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")  # output width; defaults to the height
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")  # dataset name
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")  # glob pattern for input image filenames
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")  # checkpoint directory
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")  # sample image directory
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")  # train/test switch
flags.DEFINE_boolean("is_crop", False, "True for center-cropping inputs, False otherwise [False]")  # crop switch
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")  # visualization switch
FLAGS = flags.FLAGS
def main(_):  # program entry point
  pp.pprint(flags.FLAGS.__flags)  # print the parsed command-line flags
  if FLAGS.input_width is None:  # if no input width was given,
    FLAGS.input_width = FLAGS.input_height  # use the input height
  if FLAGS.output_width is None:  # if no output width was given,
    FLAGS.output_width = FLAGS.output_height  # use the output height
  if not os.path.exists(FLAGS.checkpoint_dir):  # create the checkpoint directory if missing
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):  # create the sample directory if missing
    os.makedirs(FLAGS.sample_dir)
  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)  # cap the GPU memory fraction
  run_config = tf.ConfigProto()  # session configuration
  run_config.gpu_options.allow_growth = True  # grow GPU memory usage on demand
  with tf.Session(config=run_config) as sess:  # build the session with this config
    if FLAGS.dataset == 'mnist':  # mnist is trained conditionally on labels
        dcgan = DCGAN(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            y_dim=10,  # ten label classes
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir
        )
    else:
       dcgan = DCGAN(  # build DCGAN without a label dimension
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          is_crop=FLAGS.is_crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir)
    show_all_variables()  # print all trainable variables
    if FLAGS.is_train:  # training mode
      dcgan.train(FLAGS)  # train with the given flags
    else:  # test mode
      if not dcgan.load(FLAGS.checkpoint_dir)[0]:  # no trained model in the checkpoint directory
        raise Exception("[!] Train a model first, then run test mode")

    # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],  # dump weights (w, b, bn) to JSON for the web demo
    #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
    #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
    #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
    #                 [dcgan.h4_w, dcgan.h4_b, None])
    # Below is code for visualization
    OPTION = 1
    visualize(sess, dcgan, FLAGS, OPTION)  # run visualization with the session, model, flags, and option
if __name__ == '__main__':  # run main only when this script is executed directly
  tf.app.run()  # parse FLAGS and dispatch to main()


main.py is the program's entry point: it parses the flags, instantiates the model defined in model.py, and runs training or testing using the image-processing helpers.

utils.py


"""
Some codes from https://github.com/Newmu/dcgan_code
"""
from __future__ import division
import math
import json
import random
import pprint
import scipy.misc
import numpy as np
from time import gmtime, strftime
from six.moves import xrange

import tensorflow as tf
import tensorflow.contrib.slim as slim

# pretty-printer, handy for dumping data structures
pp = pprint.PrettyPrinter()

# 1 / sqrt(k_w * k_h * input channels): a fan-in-style scale, presumably for weight initialization
get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])

# tf.trainable_variables returns the list of variables to be trained;
# slim's model_analyzer.analyze_vars then prints information about each of them
def show_all_variables():
  model_vars = tf.trainable_variables()
  slim.model_analyzer.analyze_vars(model_vars, print_info=True)


# Read the image at image_path (optionally as grayscale), then crop/resize it per the given sizes.
def get_image(image_path, input_height, input_width,
              resize_height=64, resize_width=64,
              is_crop=True, is_grayscale=False):
  image = imread(image_path, is_grayscale)
  return transform(image, input_height, input_width,
                   resize_height, resize_width, is_crop)

# Saves images via imsave(inverse_transform(images), size, image_path) and returns the result.
def save_images(images, size, image_path):
  return imsave(inverse_transform(images), size, image_path)

# Wraps scipy.misc.imread(); the is_grayscale flag reads a flattened (grayscale) image; casts to np.float.
def imread(path, is_grayscale = False):
  if (is_grayscale):
    return scipy.misc.imread(path, flatten = True).astype(np.float)
  else:
    return scipy.misc.imread(path).astype(np.float)

# Returns inverse_transform(images); the size argument is unused.
def merge_images(images, size):
  return inverse_transform(images)

# Get each image's height and width, then branch on whether the images are RGB(A) or grayscale.
# For 3- or 4-channel images, allocate a zero canvas of size[0] x size[1] tiles (e.g. 8*8 for a
# batch of 64), paste every image of the batch into its grid cell, and return the big image.
# Single-channel images are handled the same way with one channel; anything else raises an error.
def merge(images, size):
  h, w = images.shape[1], images.shape[2]
  if (images.shape[3] in (3,4)):
    c = images.shape[3]
    img = np.zeros((h * size[0], w * size[1], c))
    for idx, image in enumerate(images):
      i = idx % size[1]
      j = idx // size[1]
      img[j * h:j * h + h, i * w:i * w + w, :] = image
    return img
  elif images.shape[3]==1:
    img = np.zeros((h * size[0], w * size[1]))
    for idx, image in enumerate(images):
      i = idx % size[1]
      j = idx // size[1]
      img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]
    return img
  else:
    raise ValueError('in merge(images,size) images parameter '
                     'must have dimensions: HxW or HxWx3 or HxWx4')


# Squeeze any length-1 axes out of the merged grid with np.squeeze(),
# then save the new image to the given path with scipy.misc.imsave().
def imsave(images, size, path):
  return scipy.misc.imsave(path, np.squeeze(merge(images, size)))

# Center-crop x to (crop_h, crop_w), computing the offsets from (H - crop_h)/2 and (W - crop_w)/2, then resize to (resize_h, resize_w).
def center_crop(x, crop_h, crop_w,
                resize_h=64, resize_w=64):
  if crop_w is None:
    crop_w = crop_h
  h, w = x.shape[:2]
  j = int(round((h - crop_h)/2.))
  i = int(round((w - crop_w)/2.))
  return scipy.misc.imresize(
      x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])


# Crop and resize an input image: if is_crop is true, center_crop() computes the offsets
# from the difference between the image's H/W and the crop's H/W and resizes the result;
# otherwise the image is simply scipy.misc.imresize'd to (resize_height, resize_width).
# Pixel values are then normalized from [0, 255] to [-1, 1].
def transform(image, input_height, input_width, 
              resize_height=64, resize_width=64, is_crop=True):
  if is_crop:
    cropped_image = center_crop(
      image, input_height, input_width, 
      resize_height, resize_width)
  else:
    cropped_image = scipy.misc.imresize(image, [resize_height, resize_width])
  return np.array(cropped_image)/127.5 - 1.

# Maps images from [-1, 1] back to [0, 1]; the inverse of transform()'s normalization.
def inverse_transform(images):
  return (images+1.)/2.

def to_json(output_path, *layers):
  with open(output_path, "w") as layer_f:
    lines = ""
    for w, b, bn in layers:
      layer_idx = w.name.split('/')[0].split('h')[1]

      B = b.eval()

      if "lin/" in w.name:
        W = w.eval()
        depth = W.shape[1]
      else:
        W = np.rollaxis(w.eval(), 2, 0)
        depth = W.shape[0]

      biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}
      if bn != None:
        gamma = bn.gamma.eval()
        beta = bn.beta.eval()

        gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]}
        beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]}
      else:
        gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}
        beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}

      if "lin/" in w.name:
        fs = []
        for w in W.T:
          fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]})

        lines += """
          var layer_%s = {
            "layer_type": "fc", 
            "sy": 1, "sx": 1, 
            "out_sx": 1, "out_sy": 1,
            "stride": 1, "pad": 0,
            "out_depth": %s, "in_depth": %s,
            "biases": %s,
            "gamma": %s,
            "beta": %s,
            "filters": %s
          };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)
      else:
        fs = []
        for w_ in W:
          fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]})

        lines += """
          var layer_%s = {
            "layer_type": "deconv", 
            "sy": 5, "sx": 5,
            "out_sx": %s, "out_sy": %s,
            "stride": 2, "pad": 1,
            "out_depth": %s, "in_depth": %s,
            "biases": %s,
            "gamma": %s,
            "beta": %s,
            "filters": %s
          };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),
               W.shape[0], W.shape[3], biases, gamma, beta, fs)
    layer_f.write(" ".join(lines.replace("'","").split()))


# Builds a GIF with the moviepy.editor module, for visualization. The inner function
# make_frame(t) picks the frame at index len(images)/duration * t; the clip is then written as a GIF.
def make_gif(images, fname, duration=2, true_image=False):
  import moviepy.editor as mpy

  def make_frame(t):
    try:
      x = images[int(len(images)/duration*t)]
    except:
      x = images[-1]

    if true_image:
      return x.astype(np.uint8)
    else:
      return ((x+1)/2*255).astype(np.uint8)

  clip = mpy.VideoClip(make_frame, duration=duration)
  clip.write_gif(fname, fps = len(images) / duration)

# There are five options, 0 through 4.
# Option 0 samples random z and directly saves one grid of generated samples;
# option 1 sweeps one z coordinate at a time (with dataset-specific handling) and saves each
# grid with save_images(); options 2-4 build GIFs in similar ways. main.py uses option 1.
def visualize(sess, dcgan, config, option):
  if option == 0:
    z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
    samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
    save_images(samples, [8, 8], './samples/test_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime()))
  elif option == 1:
    values = np.arange(0, 1, 1./config.batch_size)
    for idx in xrange(100):
      print(" [*] %d" % idx)
      z_sample = np.zeros([config.batch_size, dcgan.z_dim])
      for kdx, z in enumerate(z_sample):
        z[idx] = values[kdx]

      if config.dataset == "mnist":
        y = np.random.choice(10, config.batch_size)
        y_one_hot = np.zeros((config.batch_size, 10))
        y_one_hot[np.arange(config.batch_size), y] = 1

        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
      else:
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

      save_images(samples, [8, 8], './samples/test_arange_%s.png' % (idx))
  elif option == 2:
    values = np.arange(0, 1, 1./config.batch_size)
    for idx in [random.randint(0, 99) for _ in xrange(100)]:
      print(" [*] %d" % idx)
      z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
      z_sample = np.tile(z, (config.batch_size, 1))
      #z_sample = np.zeros([config.batch_size, dcgan.z_dim])
      for kdx, z in enumerate(z_sample):
        z[idx] = values[kdx]

      if config.dataset == "mnist":
        y = np.random.choice(10, config.batch_size)
        y_one_hot = np.zeros((config.batch_size, 10))
        y_one_hot[np.arange(config.batch_size), y] = 1

        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
      else:
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

      try:
        make_gif(samples, './samples/test_gif_%s.gif' % (idx))
      except:
        save_images(samples, [8, 8], './samples/test_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime()))
  elif option == 3:
    values = np.arange(0, 1, 1./config.batch_size)
    for idx in xrange(100):
      print(" [*] %d" % idx)
      z_sample = np.zeros([config.batch_size, dcgan.z_dim])
      for kdx, z in enumerate(z_sample):
        z[idx] = values[kdx]

      samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
      make_gif(samples, './samples/test_gif_%s.gif' % (idx))
  elif option == 4:
    image_set = []
    values = np.arange(0, 1, 1./config.batch_size)

    for idx in xrange(100):
      print(" [*] %d" % idx)
      z_sample = np.zeros([config.batch_size, dcgan.z_dim])
      for kdx, z in enumerate(z_sample): z[idx] = values[kdx]

      image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
      make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))

    new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
        for idx in range(64) + range(63, -1, -1)]
    make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)


# Computes the sample-grid shape: h = floor(sqrt(num_images)) and w = ceil(sqrt(num_images));
# asserts that h * w == num_images, then returns (h, w).

def image_manifold_size(num_images):
  manifold_h = int(np.floor(np.sqrt(num_images)))
  manifold_w = int(np.ceil(np.sqrt(num_images)))
  assert manifold_h * manifold_w == num_images
  return manifold_h, manifold_w


This file defines the various image-processing helper functions; it effectively serves as a shared header for the other three files.
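For instance, a quick check of merge() (assuming utils.py and its dependencies are importable; the random batch is just a stand-in):

import numpy as np
from utils import merge

# Tile 64 random 8x8 RGB "images" into an 8x8 grid: one 64x64x3 canvas.
images = np.random.rand(64, 8, 8, 3)
grid = merge(images, [8, 8])
print(grid.shape)  # (64, 64, 3)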

step0: pp = pprint.PrettyPrinter() is defined first, to make printing data structures convenient.

step1: get_stddev is defined: the reciprocal of the square root of the kernel area times the input channel count, presumably used as an initialization scale.

step2: show_all_variables() is defined. tf.trainable_variables returns the list of variables to be trained; model_analyzer.analyze_vars from tensorflow.contrib.slim then prints information about all of them.

Note: steps 3 through 11 all define image-processing functions that call one another; the normalization pair is worth a concrete check, below.
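transform() ends with x / 127.5 - 1, mapping pixel values to [-1, 1] (the range of the generator's tanh output), and inverse_transform() computes (x + 1) / 2 to map them back to [0, 1] for saving:

import numpy as np

pixels = np.array([0., 127.5, 255.])
normalized = pixels / 127.5 - 1.0    # -> [-1.  0.  1.]
restored = (normalized + 1.0) / 2.0  # -> [0.   0.5  1. ]
print(normalized, restored)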

ops.py


'''
This file defines the helpers used by the model: tensor concatenation, batch
normalization, convolution, deconvolution, the activation function, and
linear (fully connected) operations.
'''
import math
import numpy as np 
import tensorflow as tf

from tensorflow.python.framework import ops

from utils import *

try:
  image_summary = tf.image_summary
  scalar_summary = tf.scalar_summary
  histogram_summary = tf.histogram_summary
  merge_summary = tf.merge_summary
  SummaryWriter = tf.train.SummaryWriter
except:
  image_summary = tf.summary.image
  scalar_summary = tf.summary.scalar
  histogram_summary = tf.summary.histogram
  merge_summary = tf.summary.merge
  SummaryWriter = tf.summary.FileWriter

if "concat_v2" in dir(tf):
  def concat(tensors, axis, *args, **kwargs):
    return tf.concat_v2(tensors, axis, *args, **kwargs)
else:
  def concat(tensors, axis, *args, **kwargs):
    return tf.concat(tensors, axis, *args, **kwargs)

'''
Defines a batch_norm class with two methods, __init__ and __call__.
__init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm") opens a variable
scope with the given name and stores epsilon, momentum, and name on self.
__call__(self, x, train=True) applies tf.contrib.layers.batch_norm to x.
'''
class batch_norm(object):
  def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
    with tf.variable_scope(name):
      self.epsilon  = epsilon
      self.momentum = momentum
      self.name = name

  def __call__(self, x, train=True):
    return tf.contrib.layers.batch_norm(x,
                      decay=self.momentum, 
                      updates_collections=None,
                      epsilon=self.epsilon,
                      scale=True,
                      is_training=train,
                      scope=self.name)


# Concatenates the conditioning tensor y, broadcast over x's spatial dimensions, onto x's channel axis.
def conv_cond_concat(x, y):
  """Concatenate conditioning vector on feature map axis."""
  x_shapes = x.get_shape()
  y_shapes = y.get_shape()
  return concat([
    x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)

# Convolution: create truncated-normal weights, convolve with stride (d_h, d_w), add a zero-initialized bias, and return the result.
def conv2d(input_, output_dim, 
       k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
       name="conv2d"):
  with tf.variable_scope(name):
    w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
              initializer=tf.truncated_normal_initializer(stddev=stddev))
    conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')

    biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
    conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())

    return conv


# Transposed convolution ("deconvolution"): create random-normal weights, apply conv2d_transpose,
# and add a zero-initialized bias; if with_w is true, also return the weights and biases.

def deconv2d(input_, output_shape,
       k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
       name="deconv2d", with_w=False):
  with tf.variable_scope(name):
    # filter : [height, width, output_channels, in_channels]
    w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
              initializer=tf.random_normal_initializer(stddev=stddev))
    
    try:
      deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                strides=[1, d_h, d_w, 1])

    # Support for verisons of TensorFlow before 0.7.0
    except AttributeError:
      deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
                strides=[1, d_h, d_w, 1])

    biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
    deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

    if with_w:
      return deconv, w, biases
    else:
      return deconv

# Leaky ReLU activation: max(x, leak * x).
def lrelu(x, leak=0.2, name="lrelu"):
  return tf.maximum(x, leak*x)

# Linear (fully connected) layer: a random-normal weight matrix and a constant-initialized bias;
# returns xW + b, plus W and b when with_w is true.
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
  shape = input_.get_shape().as_list()

  with tf.variable_scope(scope or "Linear"):
    matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
                 tf.random_normal_initializer(stddev=stddev))
    bias = tf.get_variable("bias", [output_size],
      initializer=tf.constant_initializer(bias_start))
    if with_w:
      return tf.matmul(input_, matrix) + bias, matrix, bias
    else:
      return tf.matmul(input_, matrix) + bias

This file defines the helpers used by the model: tensor concatenation, batch normalization, convolution, deconvolution, the activation function, and linear operations.
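A quick, hypothetical shape check of the two workhorses above (assumes TF 1.x and that ops.py is on the import path):

import tensorflow as tf
from ops import conv2d, deconv2d

# Both default to 5x5 kernels with stride 2 and SAME padding, so conv2d
# halves the spatial size while deconv2d doubles it.
x = tf.zeros([64, 64, 64, 3])                # NHWC batch of 64x64 RGB images
y = conv2d(x, 128)                           # -> (64, 32, 32, 128)
z = deconv2d(y, [64, 64, 64, 3], name='up')  # -> (64, 64, 64, 3)
print(y.get_shape(), z.get_shape())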


model.py

from __future__ import division
import os
import time
import math
from glob import glob
import tensorflow as tf
import numpy as np
from six.moves import xrange

from ops import *
from utils import *

def conv_out_size_same(size, stride):
  return math.ceil(float(size) / float(stride))


# The DCGAN class; all of the remaining code defines this class, so the following steps live inside it.
class DCGAN(object):
  def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
         batch_size=64, sample_num = 64, output_height=64, output_width=64,
         y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
         gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
         input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None):
    """
    Args:
      sess: TensorFlow session
      batch_size: The size of batch. Should be specified before training.
      y_dim: (optional) Dimension of dim for y. [None]
      z_dim: (optional) Dimension of dim for Z. [100]
      gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
      df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
      gfc_dim: (optional) Dimension of gen units for for fully connected layer. [1024]
      dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
      c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]

      Initializes the default parameters: session, crop flag, batch size, sample count,
      input/output height and width, the various dimensions, the generator's and
      discriminator's batch-norm layers, dataset name, and grayscale flag, then builds
      the model. Note: if the dataset name is mnist, the data is loaded directly with
      load_mnist(); otherwise images are globbed from the local data folder and read as
      grayscale when they have a single channel.
    """
    self.sess = sess
    self.is_crop = is_crop
    self.is_grayscale = (c_dim == 1)

    self.batch_size = batch_size
    self.sample_num = sample_num

    self.input_height = input_height
    self.input_width = input_width
    self.output_height = output_height
    self.output_width = output_width

    self.y_dim = y_dim
    self.z_dim = z_dim

    self.gf_dim = gf_dim
    self.df_dim = df_dim

    self.gfc_dim = gfc_dim
    self.dfc_dim = dfc_dim

    self.c_dim = c_dim

    # batch normalization : deals with poor initialization helps gradient flow
    self.d_bn1 = batch_norm(name='d_bn1')
    self.d_bn2 = batch_norm(name='d_bn2')

    if not self.y_dim:
      self.d_bn3 = batch_norm(name='d_bn3')

    self.g_bn0 = batch_norm(name='g_bn0')
    self.g_bn1 = batch_norm(name='g_bn1')
    self.g_bn2 = batch_norm(name='g_bn2')

    if not self.y_dim:
      self.g_bn3 = batch_norm(name='g_bn3')

    self.dataset_name = dataset_name
    self.input_fname_pattern = input_fname_pattern
    self.checkpoint_dir = checkpoint_dir

    if self.dataset_name == 'mnist':
      self.data_X, self.data_y = self.load_mnist()
      self.c_dim = self.data_X[0].shape[-1]
    else:
      self.data = glob(os.path.join("./data", self.dataset_name, self.input_fname_pattern))
      imreadImg = imread(self.data[0])
      if len(imreadImg.shape) >= 3:  # check if image is a non-grayscale image by checking channel number
        self.c_dim = imread(self.data[0]).shape[-1]
      else:
        self.c_dim = 1

    self.grayscale = (self.c_dim == 1)

    self.build_model()

  def build_model(self):
    if self.y_dim:
      self.y= tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')

    if self.is_crop:
      image_dims = [self.output_height, self.output_width, self.c_dim]
    else:
      image_dims = [self.input_height, self.input_width, self.c_dim]

    self.inputs = tf.placeholder(
      tf.float32, [self.batch_size] + image_dims, name='real_images')
    self.sample_inputs = tf.placeholder(
      tf.float32, [self.sample_num] + image_dims, name='sample_inputs')

    # placeholders holding the real data
    inputs = self.inputs
    sample_inputs = self.sample_inputs

    # define and initialize the generator's noise input z and its summary z_sum
    self.z = tf.placeholder(
      tf.float32, [None, self.z_dim], name='z')
    self.z_sum = histogram_summary("z", self.z)

    # Check y_dim again: if set, build the generator G from noise z and labels y, the discriminator
    # D/D_logits from the real inputs, the sampler, and D_/D_logits_ from G and y (reusing variables);
    # otherwise build the same variables without the label y.
    if self.y_dim:
      self.G = self.generator(self.z, self.y)
      self.D, self.D_logits = \
          self.discriminator(inputs, self.y, reuse=False)

      self.sampler = self.sampler(self.z, self.y)
      self.D_, self.D_logits_ = \
          self.discriminator(self.G, self.y, reuse=True)
    else:
      self.G = self.generator(self.z)
      self.D, self.D_logits = self.discriminator(inputs)

      self.sampler = self.sampler(self.z)
      self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)

    self.d_sum = histogram_summary("d", self.D)
    self.d__sum = histogram_summary("d_", self.D_)
    self.G_sum = image_summary("G", self.G)

    def sigmoid_cross_entropy_with_logits(x, y):
      try:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
      except:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)

    self.d_loss_real = tf.reduce_mean(
      sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))
    self.d_loss_fake = tf.reduce_mean(
      sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))
    self.g_loss = tf.reduce_mean(
      sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))

    self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)
    self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)
                          
    self.d_loss = self.d_loss_real + self.d_loss_fake

    self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
    self.d_loss_sum = scalar_summary("d_loss", self.d_loss)

    t_vars = tf.trainable_variables()

    self.d_vars = [var for var in t_vars if 'd_' in var.name]
    self.g_vars = [var for var in t_vars if 'g_' in var.name]

    self.saver = tf.train.Saver()

  def train(self, config):
    """Train DCGAN"""

    d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
              .minimize(self.d_loss, var_list=self.d_vars)
    g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
              .minimize(self.g_loss, var_list=self.g_vars)
    try:
      tf.global_variables_initializer().run()
    except:
      tf.initialize_all_variables().run()

    self.g_sum = merge_summary([self.z_sum, self.d__sum,
      self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
    self.d_sum = merge_summary(
        [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
    self.writer = SummaryWriter("./logs", self.sess.graph)

    sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim))


    if config.dataset == 'mnist':
      sample_inputs = self.data_X[0:self.sample_num]
      sample_labels = self.data_y[0:self.sample_num]
    else:
      sample_files = self.data[0:self.sample_num]
      sample = [
        get_image(sample_file,
                  input_height=self.input_height,
                  input_width=self.input_width,
                  resize_height=self.output_height,
                  resize_width=self.output_width,
                  is_crop=self.is_crop,
                  is_grayscale=self.is_grayscale)  for sample_file in sample_files ]

      if (self.grayscale):
        sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None]
      else:
        sample_inputs = np.array(sample).astype(np.float32)
  
    counter = 1
    start_time = time.time()

    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:
      counter = checkpoint_counter
      print(" [*] Load SUCCESS")
    else:
      print(" [!] Load failed...")


    for epoch in xrange(config.epoch):
      if config.dataset == 'mnist':
        batch_idxs = min(len(self.data_X), config.train_size) // config.batch_size
      else:      
        self.data = glob(os.path.join(
          "./data", config.dataset, self.input_fname_pattern))
        batch_idxs = min(len(self.data), config.train_size) // config.batch_size

      for idx in xrange(0, batch_idxs):
        if config.dataset == 'mnist':
          batch_images = self.data_X[idx*config.batch_size:(idx+1)*config.batch_size]
          batch_labels = self.data_y[idx*config.batch_size:(idx+1)*config.batch_size]
        else:
          batch_files = self.data[idx*config.batch_size:(idx+1)*config.batch_size]
          batch = [
              get_image(batch_file,
                        input_height=self.input_height,
                        input_width=self.input_width,
                        resize_height=self.output_height,
                        resize_width=self.output_width,
                        is_crop=self.is_crop,
                        is_grayscale=self.is_grayscale) for batch_file in batch_files]
          if (self.is_grayscale):
            batch_images = np.array(batch).astype(np.float32)[:, :, :, None]
          else:
            batch_images = np.array(batch).astype(np.float32)

        batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
              .astype(np.float32)

        if config.dataset == 'mnist':
          # Update D network
          _, summary_str = self.sess.run([d_optim, self.d_sum],
            feed_dict={ 
              self.inputs: batch_images,
              self.z: batch_z,
              self.y:batch_labels,
            })
          self.writer.add_summary(summary_str, counter)

          # Update G network
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={
              self.z: batch_z, 
              self.y:batch_labels,
            })
          self.writer.add_summary(summary_str, counter)

          # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z, self.y:batch_labels })
          self.writer.add_summary(summary_str, counter)
          
          errD_fake = self.d_loss_fake.eval({
              self.z: batch_z, 
              self.y:batch_labels
          })
          errD_real = self.d_loss_real.eval({
              self.inputs: batch_images,
              self.y:batch_labels
          })
          errG = self.g_loss.eval({
              self.z: batch_z,
              self.y: batch_labels
          })
        else:
          # Update D network
          _, summary_str = self.sess.run([d_optim, self.d_sum],
            feed_dict={ self.inputs: batch_images, self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Update G network
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)
          
          errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
          errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
          errG = self.g_loss.eval({self.z: batch_z})

        counter += 1
        print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
          % (epoch, idx, batch_idxs,
            time.time() - start_time, errD_fake+errD_real, errG))

        if np.mod(counter, 100) == 1:
          if config.dataset == 'mnist':
            samples, d_loss, g_loss = self.sess.run(
              [self.sampler, self.d_loss, self.g_loss],
              feed_dict={
                  self.z: sample_z,
                  self.inputs: sample_inputs,
                  self.y:sample_labels,
              }
            )
            save_images(samples, [8, 8],
                  './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
            print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 
          else:
            try:
              samples, d_loss, g_loss = self.sess.run(
                [self.sampler, self.d_loss, self.g_loss],
                feed_dict={
                    self.z: sample_z,
                    self.inputs: sample_inputs,
                },
              )
              save_images(samples, [8, 8],
                    './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
              print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 
            except:
              print("one pic error!...")

        if np.mod(counter, 100) == 2:
          self.save(config.checkpoint_dir, counter)

  def discriminator(self, image, y=None, reuse=False):
    with tf.variable_scope("discriminator") as scope:
      if reuse:
        scope.reuse_variables()

      if not self.y_dim:
        h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
        h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv')))
        h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))
        h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv')))
        h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin')

        return tf.nn.sigmoid(h4), h4
      else:
        yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
        x = conv_cond_concat(image, yb)

        h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv'))
        h0 = conv_cond_concat(h0, yb)

        h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv')))
        h1 = tf.reshape(h1, [self.batch_size, -1])      
        h1 = concat([h1, y], 1)
        
        h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin')))
        h2 = concat([h2, y], 1)

        h3 = linear(h2, 1, 'd_h3_lin')
        
        return tf.nn.sigmoid(h3), h3

  def generator(self, z, y=None):
    with tf.variable_scope("generator") as scope:
      if not self.y_dim:
        s_h, s_w = self.output_height, self.output_width
        s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
        s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
        s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
        s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

        # project `z` and reshape
        self.z_, self.h0_w, self.h0_b = linear(
            z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin', with_w=True)

        self.h0 = tf.reshape(
            self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])
        h0 = tf.nn.relu(self.g_bn0(self.h0))

        self.h1, self.h1_w, self.h1_b = deconv2d(
            h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1', with_w=True)
        h1 = tf.nn.relu(self.g_bn1(self.h1))

        h2, self.h2_w, self.h2_b = deconv2d(
            h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2', with_w=True)
        h2 = tf.nn.relu(self.g_bn2(h2))

        h3, self.h3_w, self.h3_b = deconv2d(
            h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3', with_w=True)
        h3 = tf.nn.relu(self.g_bn3(h3))

        h4, self.h4_w, self.h4_b = deconv2d(
            h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)

        return tf.nn.tanh(h4)
      else:
        s_h, s_w = self.output_height, self.output_width
        s_h2, s_h4 = int(s_h/2), int(s_h/4)
        s_w2, s_w4 = int(s_w/2), int(s_w/4)

        # yb = tf.expand_dims(tf.expand_dims(y, 1),2)
        yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
        z = concat([z, y], 1)

        h0 = tf.nn.relu(
            self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
        h0 = concat([h0, y], 1)

        h1 = tf.nn.relu(self.g_bn1(
            linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin')))
        h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])

        h1 = conv_cond_concat(h1, yb)

        h2 = tf.nn.relu(self.g_bn2(deconv2d(h1,
            [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2')))
        h2 = conv_cond_concat(h2, yb)

        return tf.nn.sigmoid(
            deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))

  def sampler(self, z, y=None):
    with tf.variable_scope("generator") as scope:
      scope.reuse_variables()

      if not self.y_dim:
        s_h, s_w = self.output_height, self.output_width
        s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
        s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
        s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
        s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

        # project `z` and reshape
        h0 = tf.reshape(
            linear(z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin'),
            [-1, s_h16, s_w16, self.gf_dim * 8])
        h0 = tf.nn.relu(self.g_bn0(h0, train=False))

        h1 = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1')
        h1 = tf.nn.relu(self.g_bn1(h1, train=False))

        h2 = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2')
        h2 = tf.nn.relu(self.g_bn2(h2, train=False))

        h3 = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3')
        h3 = tf.nn.relu(self.g_bn3(h3, train=False))

        h4 = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')

        return tf.nn.tanh(h4)
      else:
        s_h, s_w = self.output_height, self.output_width
        s_h2, s_h4 = int(s_h/2), int(s_h/4)
        s_w2, s_w4 = int(s_w/2), int(s_w/4)

        # yb = tf.reshape(y, [-1, 1, 1, self.y_dim])
        yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
        z = concat([z, y], 1)

        h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
        h0 = concat([h0, y], 1)

        h1 = tf.nn.relu(self.g_bn1(
            linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'), train=False))
        h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])
        h1 = conv_cond_concat(h1, yb)

        h2 = tf.nn.relu(self.g_bn2(
            deconv2d(h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'), train=False))
        h2 = conv_cond_concat(h2, yb)

        return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))

  def load_mnist(self):
    data_dir = os.path.join("./data", self.dataset_name)
    
    fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float)

    fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    trY = loaded[8:].reshape((60000)).astype(np.float)

    fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float)

    fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    teY = loaded[8:].reshape((10000)).astype(np.float)

    trY = np.asarray(trY)
    teY = np.asarray(teY)
    
    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int)
    
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)
    
    y_vec = np.zeros((len(y), self.y_dim), dtype=np.float)
    for i, label in enumerate(y):
      y_vec[i,y[i]] = 1.0
    
    return X/255.,y_vec

  @property
  def model_dir(self):
    return "{}_{}_{}_{}".format(
        self.dataset_name, self.batch_size,
        self.output_height, self.output_width)
      
  def save(self, checkpoint_dir, step):
    model_name = "DCGAN.model"
    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

    if not os.path.exists(checkpoint_dir):
      os.makedirs(checkpoint_dir)

    self.saver.save(self.sess,
            os.path.join(checkpoint_dir, model_name),
            global_step=step)


  def load(self, checkpoint_dir):
    import re
    print(" [*] Reading checkpoints...")
    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
      self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
      counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
      print(" [*] Success to read {}".format(ckpt_name))
      return True, counter
    else:
      print(" [*] Failed to find a checkpoint")
      return False, 0

This file defines the DCGAN model itself, calling into utils.py and ops.py.

step0: define conv_out_size_same(size, stride), which returns ceil(size / stride): the output side length of a stride-`stride`, SAME-padded convolution (e.g. conv_out_size_same(64, 2) == 32).

step1: define the DCGAN class; all of the remaining code lives inside this class, so the following steps all refer to it.

step2: define the class initializer __init__. It mainly initializes the default parameters: session, crop flag, batch size batch_size, sample count sample_num, input/output height and width, the various dimensions, the generator's and discriminator's batch norms, dataset name, and grayscale flag, then calls the model-building function. Note that if the dataset name is mnist, the data is loaded directly with load_mnist(); otherwise the images are read from the local data folder, as grayscale when they have a single channel.

step3: define the model-building function build_model(self).

  1. Check y_dim; if set, define and initialize the label placeholder y with tf.placeholder.
  2. Check is_crop: if true we are testing, and the image dimensions are the output dimensions; otherwise they are the input dimensions.
  3. Define inputs with tf.placeholder: the batch of real images.
  4. Define and initialize the generator's noise input z and its summary z_sum.
  5. Check y_dim again: if set, initialize the generator G from noise z and labels y, the discriminator D and D_logits from inputs, the sampler, and D_ and D_logits_ from G and y; otherwise initialize the same variables without labels.
  6. Record D, D_, and G in the summaries d_sum, d__sum, and G_sum.
  7. Define sigmoid_cross_entropy_with_logits(x, y), a thin wrapper around tf.nn.sigmoid_cross_entropy_with_logits that handles the keyword-argument rename (targets in older TensorFlow versions, labels in newer ones).
  8. Define the losses: the discriminator's loss on real data d_loss_real, its loss on fake data d_loss_fake, the generator loss g_loss, and the total discriminator loss d_loss, as formalized below.
  9. Collect all trainable variables in t_vars.
  10. Split them into the discriminator's and the generator's parameter sets.
  11. Finally, create the saver.
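In step 8, the sigmoid cross-entropies with all-ones or all-zeros labels amount to the standard GAN objective with the non-saturating generator loss:

d_loss = d_loss_real + d_loss_fake = -E[log D(x)] - E[log(1 - D(G(z)))]
g_loss = -E[log D(G(z))]

so the discriminator pushes D(x) toward 1 and D(G(z)) toward 0, while the generator pushes D(G(z)) toward 1.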

step4: define the training function train(self, config).

  1. Define the discriminator optimizer d_optim and the generator optimizer g_optim (both Adam).
  2. Initialize the variables.
  3. Merge the generator-related and discriminator-related summaries into one variable each, and create the event-file writer.
  4. Initialize the sample noise z.
  5. Fetch the sample inputs and labels, depending on whether the dataset is mnist; this uses the get_image function from utils.py.
  6. Define the counter and the start time start_time.
  7. Load a checkpoint if one exists and report whether that succeeded.
  8. Start the for epoch in xrange(config.epoch) training loop; depending on whether the dataset is mnist, compute the number of batches.
  9. Start the for idx in xrange(0, batch_idxs) loop; depending on the dataset, build the batch of images (and labels).
  10. Draw the batch noise z.
  11. Update the discriminator and generator networks (with mnist-specific feeds where needed). Setting aside the mnist handling: each step runs the generator optimizer twice so that the discriminator loss does not collapse to zero, then evaluates the discriminator's real and fake losses and the generator loss.
  12. Print the training status for this batch: the epoch number, batch number, elapsed time, discriminator loss, and generator loss.
  13. Every 100 batches, fetch samples and the two losses (with dataset-specific feeds), save the samples via save_images from utils.py with the epoch and batch numbers in the filename, and print the two losses.
  14. Every 100 batches (when counter % 100 == 2), save a checkpoint.

step5: define the discriminator function discriminator(self, image, y=None, reuse=False).

  1. Use with tf.variable_scope("discriminator") as scope to share variables within a single scope.
  2. When reuse is true, call scope.reuse_variables() so the second discriminator reuses the first one's weights.
  3. Without y_dim, there are five layers: four convolutional layers with the lrelu activation, then a linear layer; the function returns sigmoid(h4) together with h4.
  4. With y_dim, y is first reshaped to yb and concatenated onto the image via conv_cond_concat from ops.py to get x; then come four layers, three lrelu-activated convolutional layers and one linear layer, and the function returns sigmoid(h3) together with h3.

step6: define the generator function generator(self, z, y=None).

  1. Use with tf.variable_scope("generator") as scope to share variables within a single scope.
  2. Branch on y_dim to set up the network.
  3. Without y_dim: first get the output height and width, then repeatedly halve them to get the intermediate sizes (for a 64*64 output these are 32, 16, 8, and 4, so z is projected to a 4*4*(gf_dim*8) block). Layer h0 projects the noise z linearly (also returning the weights and bias) and applies relu after batch norm. Layer h1 deconvolves h0 and applies relu; h2 and h3 follow the same pattern. Layer h4 deconvolves h3, and the function returns tanh(h4).
  4. With y_dim: likewise compute the output sizes and their halves and quarters, then build yb and concatenate y onto z. Layer h0 is a relu-activated linear layer concatenated with y; h1 is a relu-activated linear layer, reshaped and concatenated with yb; h2 is a relu-activated deconvolution concatenated with yb; finally the sigmoid of one more deconvolution is returned.

step7: define the sampler(self, z, y=None) function.

  1. Use tf.variable_scope("generator") as scope to share variables with the generator.
  2. Call scope.reuse_variables() so the sampler reuses the generator's trained weights.
  3. Branch on y_dim to set up the network.
  4. The rest mirrors the generator, except that batch norm runs with train=False; no need to repeat it.

step8: define the load_mnist(self) function. This is specific to the mnist dataset, so we skip the details here.

step9: define the model_dir property, which returns a string built from the dataset name, the batch size, and the output height and width.

step10: define the save(self, checkpoint_dir, step) function, which saves the trained model: it creates the checkpoint folder if the path does not exist, then saves the session there.

step11: define the load(self, checkpoint_dir) function, which reads the checkpoint state from the checkpoint path, restores the session, recovers the step counter from the checkpoint filename, and prints a success message; if no checkpoint is found, it prints a failure message.

That is all of model.py: it defines the DCGAN class, implementing both the generator and the discriminator networks.
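Putting the pieces together, here is a minimal, hypothetical programmatic use of the class (assuming the four files plus a populated data/faces folder are in place; main.py does exactly this for you from the flags):

import tensorflow as tf
from model import DCGAN

with tf.Session() as sess:
    # Constructing the object also builds the graph (build_model is called in __init__).
    dcgan = DCGAN(sess, input_height=96, input_width=96,
                  output_height=64, output_width=64,
                  dataset_name='faces', is_crop=True,
                  checkpoint_dir='checkpoint', sample_dir='samples')
    # dcgan.train(config) expects a FLAGS-like object carrying epoch, learning_rate,
    # beta1, train_size, batch_size, dataset, sample_dir, and checkpoint_dir.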


Training

Now that all four files have been walked through, let's run the code.

step0: since we are using the anime-face dataset, create a data folder next to the source files and put the folder containing the images inside it, as sketched below.
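Reconstructed from the glob in model.py ("./data", dataset_name, input_fname_pattern), the expected layout is roughly as follows (the image file names are illustrative):

DCGAN-tensorflow/
├── main.py
├── model.py
├── ops.py
├── utils.py
└── data/
    └── faces/
        ├── 0001.jpg
        ├── 0002.jpg
        └── ...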




The run command is below; you need to specify the various parameters: the input and output sizes, which dataset to use, whether to crop and whether to train, and how many epochs to run.

python3 main.py --input_height 96 --output_height 48 --dataset faces --is_crop True --is_train True --epoch 10
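After training, running the same command without --is_train True (it defaults to False) loads the latest checkpoint and runs the visualization path (OPTION = 1 in main.py):

python3 main.py --input_height 96 --output_height 48 --dataset faces --is_crop True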

If you work in an IDE instead, such as PyCharm, configure the same arguments in the run configuration. (Console output screenshot omitted.)



A few generated results (sample images omitted): from epoch 1, the saved grids train_00_0099.png, train_00_0299.png, and train_00_0499.png; then samples at epochs 10, 20, 200, and 300 as training continues.

Reposted from blog.csdn.net/Gavinmiaoc/article/details/80066700