CycleGAN: How It Works (with runnable source code)

Preface

CycleGAN was posted on arXiv at the end of March 2017 (paper: CycleGAN), under the title Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks. Two very similar papers appeared at almost the same time: DualGAN (paper: DualGAN) and DiscoGAN, whose full title is Learning to Discover Cross-Domain Relations with Generative Adversarial Networks (paper: DiscoGAN). In short, all three do the same thing: automatically translate images of one class into images of another class. Unlike GAN and CGAN (introduced in the previous section), CycleGAN does not require paired training images. Paired images work too, of course, but most of the time pairs are hard to obtain.

[Figure: paired images]
[Figure: unpaired images]

What can CycleGAN do?

CycleGAN can do everything GAN and CGAN can: as the paired-image figure above shows, it can translate a scene rendered in one mode into the same scene rendered in another mode, with the objects in the two renderings staying identical. Beyond that, CycleGAN can also perform translations in which the objects themselves change, for example cat to dog, or man to woman, as in the images below.

[Figure: cat-to-dog and man-to-woman translation examples]

How CycleGAN works

As the figure below shows, CycleGAN consists of two discriminators (D_X and D_Y) and two generators (G and F). Why two of each? The paper explains that a single unconstrained mapping could collapse all of X onto the same Y; for example, every image of a man could be mapped to the same image of Fan Bingbing, which is clearly useless. To avoid this, the paper uses two generators, modeling both the X→Y mapping and the Y→X mapping; the resulting cycle-consistency constraint works much like the reconstruction loss of an autoencoder, forcing different input images to produce different output images. With that in mind, the four formulas below are straightforward: (1) is the loss of discriminator D_Y against the X→Y mapping G (the loss of D_X against the Y→X mapping F is analogous); (2) is the cycle-consistency loss of the two generators, which is in fact an L1 loss; (3) is the total loss; and (4) is the optimization of the total loss, first optimizing D and then optimizing G and F, just as in a standard GAN.

[Figure: CycleGAN architecture, two generators G: X→Y and F: Y→X, two discriminators D_X and D_Y, and the forward and backward cycle-consistency constraints]

$$\mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}[\log D_Y(y)] + \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log(1 - D_Y(G(x)))] \tag{1}$$

$$\mathcal{L}_{\mathrm{cyc}}(G, F) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\|F(G(x)) - x\|_1] + \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}[\|G(F(y)) - y\|_1] \tag{2}$$

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\mathrm{GAN}}(F, D_X, Y, X) + \lambda\, \mathcal{L}_{\mathrm{cyc}}(G, F) \tag{3}$$

$$G^*, F^* = \arg\min_{G, F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y) \tag{4}$$
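
To make formulas (1) through (4) concrete, here is a minimal TensorFlow 1.x sketch of the adversarial and cycle losses. The names D_real, D_fake, real_x, real_y, G and F are illustrative placeholders, not the actual symbols used in the repository's model.py; with use_lsgan=True the log terms of (1) are replaced by least-squares terms, matching the use_lsgan flag in the training script below.

import tensorflow as tf

def gan_losses(D_real, D_fake, use_lsgan=True):
  # Adversarial loss (1) for one direction, e.g. D_real = D_Y(y), D_fake = D_Y(G(x)).
  if use_lsgan:
    # LSGAN variant: mean squared error instead of log-likelihood
    d_loss = (tf.reduce_mean(tf.squared_difference(D_real, 1.0)) +
              tf.reduce_mean(tf.square(D_fake)))
    g_loss = tf.reduce_mean(tf.squared_difference(D_fake, 1.0))
  else:
    # cross-entropy variant, literally formula (1)
    d_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1.0 - D_fake))
    g_loss = -tf.reduce_mean(tf.log(D_fake))
  return d_loss, g_loss

def cycle_loss(real_x, real_y, G, F, lambda1=10.0, lambda2=10.0):
  # Cycle-consistency loss (2): L1 distance of the two round trips.
  forward = tf.reduce_mean(tf.abs(F(G(real_x)) - real_x))   # X -> Y -> X
  backward = tf.reduce_mean(tf.abs(G(F(real_y)) - real_y))  # Y -> X -> Y
  return lambda1 * forward + lambda2 * backward

The total loss (3) is then the generator losses in both directions plus cycle_loss, and (4) is realized by alternating optimizer steps on the discriminators and the generators.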

Source Code

Training code
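
Assuming the script below is saved as train.py (the file name is an assumption; any name works), training on the default apple-to-orange tfrecords can be launched with, for example:

python train.py --X tfrecords/apple.tfrecords --Y tfrecords/orange.tfrecords

All hyperparameters can be overridden through the flags defined at the top of the script.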

import tensorflow as tf
from model import CycleGAN
from reader import Reader
from datetime import datetime
import os
import logging
from utils import ImagePool

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1')
tf.flags.DEFINE_integer('image_size', 128, 'image size, default: 128')
tf.flags.DEFINE_bool('use_lsgan', True,
                     'use lsgan (mean squared error) or cross entropy loss, default: True')
tf.flags.DEFINE_string('norm', 'instance',
                       '[instance, batch] use instance norm or batch norm, default: instance')
tf.flags.DEFINE_float('lambda1', 10.0,
                      'weight for forward cycle loss (X->Y->X), default: 10.0')
tf.flags.DEFINE_float('lambda2', 10.0,
                      'weight for backward cycle loss (Y->X->Y), default: 10.0')
tf.flags.DEFINE_float('learning_rate', 2e-4,
                      'initial learning rate for Adam, default: 0.0002')
tf.flags.DEFINE_float('beta1', 0.5,
                      'momentum term of Adam, default: 0.5')
tf.flags.DEFINE_integer('pool_size', 50,
                      'size of image buffer that stores previously generated images, default: 50')
tf.flags.DEFINE_integer('ngf', 64,
                        'number of gen filters in first conv layer, default: 64')

tf.flags.DEFINE_string('X', 'tfrecords/apple.tfrecords',
                       'X tfrecords file for training, default: tfrecords/apple.tfrecords')
tf.flags.DEFINE_string('Y', 'tfrecords/orange.tfrecords',
                       'Y tfrecords file for training, default: tfrecords/orange.tfrecords')
tf.flags.DEFINE_string('load_model', None,
                        'folder of saved model that you wish to continue training (e.g. 20170602-1936), default: None')


def train():
  if FLAGS.load_model is not None:
    checkpoints_dir = "checkpoints/" + FLAGS.load_model
  else:
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
      os.makedirs(checkpoints_dir)
    except OSError:
      pass

  graph = tf.Graph()
  with graph.as_default():
    cycle_gan = CycleGAN(
        X_train_file=FLAGS.X,
        Y_train_file=FLAGS.Y,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambda1=FLAGS.lambda1,
        lambda2=FLAGS.lambda2,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        ngf=FLAGS.ngf
    )
    G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
    optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    if FLAGS.load_model is not None:
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      restore = tf.train.import_meta_graph(meta_graph_path)
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
      step = int(meta_graph_path.split("-")[2].split(".")[0])
    else:
      sess.run(tf.global_variables_initializer())
      step = 0

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      while not coord.should_stop():
        # get previously generated images
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        # train
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
              sess.run(
                  [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                  feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
              )
        )
        if step % 100 == 0:
          train_writer.add_summary(summary, step)
          train_writer.flush()
          logging.info('-----------Step %d:-------------' % step)
          logging.info('  G_loss   : {}'.format(G_loss_val))
          logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
          logging.info('  F_loss   : {}'.format(F_loss_val))
          logging.info('  D_X_loss : {}'.format(D_X_loss_val))

        if step % 1000 == 0:
          save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
          logging.info("Model saved in file: %s" % save_path)

        step += 1

    except KeyboardInterrupt:
      logging.info('Interrupted')
      coord.request_stop()
    except Exception as e:
      coord.request_stop(e)
    finally:
      save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
      logging.info("Model saved in file: %s" % save_path)
      # When done, ask the threads to stop.
      coord.request_stop()
      coord.join(threads)

def main(unused_argv):
  train()

if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)
  tf.app.run()
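
The training loop above relies on utils.ImagePool, which is imported but not shown. Below is a plausible minimal implementation, assuming the standard 50-image history buffer that the pool_size flag describes (feeding the discriminators a mix of current and previously generated fakes stabilizes training); the repository's utils.py may differ in details.

import random

class ImagePool(object):
  # Buffer of previously generated images (a sketch of utils.ImagePool).
  def __init__(self, pool_size):
    self.pool_size = pool_size
    self.images = []

  def query(self, image):
    if self.pool_size == 0:                # pooling disabled: always pass through
      return image
    if len(self.images) < self.pool_size:  # buffer not full yet: store and return as-is
      self.images.append(image)
      return image
    if random.random() > 0.5:              # 50%: swap with a randomly stored old fake
      idx = random.randrange(self.pool_size)
      old = self.images[idx]
      self.images[idx] = image
      return old
    return image                           # 50%: pass the fresh fake through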

Test (inference) code

"""Translate an image to another image
An example of command-line usage is:
python export_graph.py --model pretrained/apple2orange.pb \
                       --input input_sample.jpg \
                       --output output_sample.jpg \
                       --image_size 256
"""

import tensorflow as tf
import os
from model import CycleGAN
import utils

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('model', 'model/apple2orange.pb', 'model path (.pb)')
tf.flags.DEFINE_string('input', 'samples/real_apple2orange_4.jpg', 'input image path (.jpg)')
tf.flags.DEFINE_string('output', 'output/output_sample3.jpg', 'output image path (.jpg)')
tf.flags.DEFINE_integer('image_size', 256, 'image size, default: 256')

def inference():
  graph = tf.Graph()

  with graph.as_default():
    with tf.gfile.FastGFile(FLAGS.input, 'rb') as f:
      image_data = f.read()
      input_image = tf.image.decode_jpeg(image_data, channels=3)
      input_image = tf.image.resize_images(input_image, size=(FLAGS.image_size, FLAGS.image_size))
      input_image = utils.convert2float(input_image)
      input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(model_file.read())
    [output_image] = tf.import_graph_def(graph_def,
                          input_map={'input_image': input_image},
                          return_elements=['output_image:0'],
                          name='output')

  with tf.Session(graph=graph) as sess:
    generated = output_image.eval()
    with open(FLAGS.output, 'wb') as f:
      f.write(generated)

def main(unused_argv):
  inference()

if __name__ == '__main__':
  tf.app.run()
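
The inference script uses utils.convert2float, which is also not shown. A minimal sketch, assuming it rescales pixel values to the [-1, 1] range the generator was trained on (the repository's utils.py is the authoritative version):

def convert2float(image):
  # Map an image tensor to float32 in [-1, 1] (a sketch of utils.convert2float).
  image = tf.image.convert_image_dtype(image, dtype=tf.float32)  # -> [0, 1]
  return (image * 2.0) - 1.0                                     # -> [-1, 1]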

Experimental results

Here the model was trained on a dataset of the same kind of object in two different modes (I did not find a dataset of different objects in different modes, but you can always build your own; see the sketch right below), translating apples into oranges. The test result follows the sketch:
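
If you build your own dataset, the tfrecords files expected by the --X and --Y flags can be produced with a small script along the following lines. This is only a sketch: the feature key 'image/encoded_image' is an assumption and must match whatever key reader.py actually parses (the linked repository ships a build_data.py for exactly this purpose).

import os
import tensorflow as tf

def folder_to_tfrecords(input_dir, output_file):
  # Pack every .jpg in input_dir into a single tfrecords file.
  writer = tf.python_io.TFRecordWriter(output_file)
  for name in sorted(os.listdir(input_dir)):
    if not name.lower().endswith('.jpg'):
      continue
    with open(os.path.join(input_dir, name), 'rb') as f:
      encoded = f.read()                       # raw JPEG bytes
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded_image': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded])),
    }))
    writer.write(example.SerializeToString())
  writer.close()

# e.g. folder_to_tfrecords('data/apple2orange/trainA', 'tfrecords/apple.tfrecords')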

[Figure: apple-to-orange test result]

As the figure shows, the apple's color has been changed to orange, demonstrating the intended effect.
Source code link: CycleGAN source code

Reposted from blog.csdn.net/qq_29462849/article/details/80554706