cyclegan——训练+预测1

参考链接:
原文:https://blog.csdn.net/Exploer_TRY/article/details/81841826
代码详解:https://blog.csdn.net/greepex/article/details/86360726
公式理解:https://blog.csdn.net/Teeyohuang/article/details/82729047

下面是自己经过注释的代码:

由于代码量较大,本文只发4个关键执行文件,剩余几个文件在下一篇文章发出来。

1.build_data.py(生成tf文件) 

import tensorflow as tf
import random
import os
try:
  from os import scandir
except ImportError:
  # Python 2 polyfill module
  from scandir import scandir
    
#把全局变量用FLAGS代替,方便引用
FLAGS = tf.flags.FLAGS
#定义全局变量'X_input_dir',即输入图片x域的路径
tf.flags.DEFINE_string('X_input_dir', 'face/trainA/',
                       'X input directory, default: data/apple2orange/trainA')
#定义全局变量'Y_input_dir',即输入图片y域的路径
tf.flags.DEFINE_string('Y_input_dir', 'face/trainB/',
                       'Y input directory, default: data/apple2orange/trainB')
#定义全局变量'X_output_file',即图片x域转成tfrecorda文件后保存的路径
tf.flags.DEFINE_string('X_output_file', 'face/real.tfrecords',
                       'X output tfrecords file, default: data/tfrecords/apple.tfrecords')
#定义全局变量'Y_output_file',即图片y域转成tfrecorda文件后保存的路径
tf.flags.DEFINE_string('Y_output_file', 'face/blur.tfrecords',
                       'Y output tfrecords file, default: data/tfrecords/orange.tfrecords')


def data_reader(input_dir, shuffle=True):
  #预先定义一个空数列
  file_paths = []
  #遍历输入路径的文件夹中的文件
  for img_file in scandir(input_dir):
    #如果是jpg格式并且属于文件(非文件夹),则下一步,把文件的路径填到file_paths数列中
    if img_file.name.endswith('.jpg') and img_file.is_file():
      file_paths.append(img_file.path)

  #如果随机打乱为真
  if shuffle:
    #索引范围为:0-文件数
    shuffled_index = list(range(len(file_paths)))
    #让随机结果可重现
    random.seed(12345)
    #把索引的下标重新打乱
    random.shuffle(shuffled_index)
    #随机选择图片,然后放到数列中
    file_paths = [file_paths[i] for i in shuffled_index]
  #把数列返回出去
  return file_paths

#把对象转换为整形属性
def _int64_feature(value):
  #判断value的类型是否跟list的一样
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

#把对象转换为字符串型属性
def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _convert_to_example(file_path, image_buffer):
  #从文件的路径中分割出文件名
  file_name = file_path.split('/')[-1]
  #把对象进行转码
  example = tf.train.Example(features=tf.train.Features(feature={
    #调用_bytes_feature函数,把对象转为字符串型属性
      'image/file_name': _bytes_feature(tf.compat.as_bytes(os.path.basename(file_name))),
      'image/encoded_image': _bytes_feature((image_buffer))
    }))
  return example

def data_writer(input_dir, output_file):
  #调用data_reader函数,随机获取一张图片的路径
  file_paths = data_reader(input_dir)
  #遍历输出路径的文件夹是否存在,若没有则新建
  #遍历output_file,得到文件所在的文件夹的绝对路径,如:face/real.tfrecords---face/
  output_dir = os.path.dirname(output_file)   
  try:
    os.makedirs(output_dir)
  except os.error as e:
    pass
  #获取文件的数量
  images_num = len(file_paths)
  #创建tf文件
  writer = tf.python_io.TFRecordWriter(output_file)
  #遍历每一个图片文件
  for i in range(len(file_paths)):
    file_path = file_paths[i]
    #读取图片
    with tf.gfile.FastGFile(file_path, 'rb') as f:
      image_data = f.read()
    #创建一个整例,用于把图片和文字进行转码
    example = _convert_to_example(file_path, image_data)
    #把转码之后的信息,写入tfrecords文件
    writer.write(example.SerializeToString())   #把example序列化为一个字符串
    #做一个步长判断,每500步打印一次信息
    if i % 500 == 0:
      print("Processed {}/{}.".format(i, images_num))
  print("Done.")
  #关闭tfrecords文件
  writer.close()

def main(unused_argv):
  print("Convert X data to tfrecords...")
  #开始调用函数:data_writer(),把x域图片转为tfrecords格式
  data_writer(FLAGS.X_input_dir, FLAGS.X_output_file)
  print("Convert Y data to tfrecords...")
  #开始调用函数:data_writer(),把x域图片转为tfrecords格式
  data_writer(FLAGS.Y_input_dir, FLAGS.Y_output_file)

if __name__ == '__main__':
  # 解析命令行参数,调用main函数
  tf.app.run()

2.train.py(训练模型)

import tensorflow as tf
from model import CycleGAN
from reader import Reader
from datetime import datetime
import os
import logging
from utils import ImagePool

#定义选用"0"号gpu
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#主要是配置tf.Session的运行方式,GPU还是CPU
#参数代表,若指定的设备不存在,则允许tf自动分配设备
config = tf.ConfigProto(allow_soft_placement=True)
#在这里选择的是GPU的运行方式
config.gpu_options.allow_growth = True
#限制gpu的使用率,最大0.7
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
#确定sess会话的配置为以上配置
sess = tf.Session(config=config)

#定义全局变量,用FLAGS命名
#变量格式为("变量名",变量值,'描述')
FLAGS = tf.flags.FLAGS
#设置batch_size的大小,默认为1
tf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1')
#定义训练图片的尺寸大小
tf.flags.DEFINE_integer('image_size', 512, 'image size, default: 256')
#默认使用lsgan:最小二乘法
tf.flags.DEFINE_bool('use_lsgan', True,'use lsgan (mean squared error) or cross entropy loss, default: True')
#定义规范化:instance(实例)
tf.flags.DEFINE_string('norm', 'instance','[instance, batch] use instance norm or batch norm, default: instance')
#定义正向循环损失的权重:10
tf.flags.DEFINE_integer('lambda1', 10,'weight for forward cycle loss (X->Y->X), default: 10')
#定义向后循环损失的权重:10
tf.flags.DEFINE_integer('lambda2', 10,'weight for backward cycle loss (Y->X->Y), default: 10')
#定义学习率:0.0002
tf.flags.DEFINE_float('learning_rate', 2e-4,'initial learning rate for Adam, default: 0.0002')
#定义优化器adam的一阶矩估计的指数衰减因子:0.5,加快收敛速度
tf.flags.DEFINE_float('beta1', 0.5,'momentum term of Adam, default: 0.5')
#存储先前生成的图像的图像缓冲区的大小
tf.flags.DEFINE_float('pool_size', 50,'size of image buffer that stores previously generated images, default: 50')
#定义卷积层中的过滤器数量
tf.flags.DEFINE_integer('ngf', 64,'number of gen filters in first conv layer, default: 64')
#定义X域的数据路径,即tfrecords文件的路径
tf.flags.DEFINE_string('X', 'face/512/real.tfrecords',
                       'X tfrecords file for training, default: img/apple.tfrecords')
#定义Y域的数据路径,即tfrecords文件的路径
tf.flags.DEFINE_string('Y', 'face/512/blur.tfrecords',
                       'Y tfrecords file for training, default: img/orange.tfrecords')
#定义训练模型的保存路径,可跟着上一次训练继续进行;选None则表示重新训练保存
tf.flags.DEFINE_string('load_model', 'checkpoints/20190510-0446/',
                        'folder of saved model that you wish to continue training (e.g. 20170602-1936), default: None')


def train():
  #若果训练模型的路径不是空的,则模型路径等于FLAGS.load_model,接着上一次训练,否则建立新模型
  if FLAGS.load_model is not None:
    checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip("checkpoints/")
  else:
    #获取当前时间
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    #以当前时间命名新模型文件夹,新建文件夹
    checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
      os.makedirs(checkpoints_dir)
    except os.error:
      pass
  #定义一个数据流的图grap
  graph = tf.Graph()
  #接下来的操作都在这个默认的图中
  with graph.as_default():
    #从model.py中引用类:CycleGAN,并且把参数传递进去
    cycle_gan = CycleGAN(
        X_train_file=FLAGS.X,
        Y_train_file=FLAGS.Y,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambda1=FLAGS.lambda1,
        lambda2=FLAGS.lambda2,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        ngf=FLAGS.ngf
    )
    #调用函数cycle_gan.model(),获得各种损失值
    G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
    #调用函数cycle_gan.optimize,获得learning_step---学习的迭代次数
    optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
    #把训练时的各种信息全部保存到磁盘,以便tensorboard显示---可视化学习
    summary_op = tf.summary.merge_all()
    #定义一个写入summary的目标文件,checkpoints_dir为写入文件地址    指定一个文件用来保存图graph
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    #添加操作以保存和恢复所有变量
    saver = tf.train.Saver()

  #定义会话
  with tf.Session(graph=graph) as sess:
    #如果模型不是空的
    if FLAGS.load_model is not None:
      #获取获取模型路径
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      #获取meat文件路径
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      #从meat文件中恢复计算图---保存模型时,已经把模型的图存放到meat文件中
      restore = tf.train.import_meta_graph(meta_graph_path)
      #恢复模型,加载模型到session中关联的graph中
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
      #截取训练步长
      step = int(meta_graph_path.split("-")[2].split(".")[0])
    #若是模型是空的,即新训练的
    else:
      #初始化变量,步长从0开始
      sess.run(tf.global_variables_initializer())
      step = 0
    #Session会话是多线程的
    #tf.train.Coordinator()创建一个线程管理器(协调器)对象,用来管理之后在Session中启动的所有线程
    coord = tf.train.Coordinator()
    #启动入队线程,函数返回线程ID的列表,一般情况下,系统有多少个核,就会启动多少个入队线程
    #(入队具体使用多少个线程在tf.train.batch中定义)
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      #从unit.py引入类:ImagePool,获取一个图像缓冲区——50张图片,然后随机打乱顺序
      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      while not coord.should_stop():
        #计算fake_y、fake_x
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        #计算optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (sess.run(
          [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
          feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}))
        #add_summary函数向FileWriter对象的缓存中存放计算数据,并传入步长,使其变成曲线
        train_writer.add_summary(summary, step)
        #刷新缓冲区,即将缓冲区中的数据立刻写入文件,同时清空缓冲区,不需要是被动的等待输出缓冲区写入
        #一般情况下,文件关闭后会自动刷新缓冲区,但有时你需要在关闭前刷新它,这时就可以使用 flush() 方法。
        train_writer.flush()
        #以训练步长为10,来打印训练的日志信息
        if step % 10 == 0:
          logging.info('-----------Step %d:-------------' % step)
          logging.info('  G_loss   : {}'.format(G_loss_val))
          logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
          logging.info('  F_loss   : {}'.format(F_loss_val))
          logging.info('  D_X_loss : {}'.format(D_X_loss_val))
        #以训练步长为500,来保存训练模型
        if step % 500 == 0:
          #保存模型
          save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
          #打印日志信息
          logging.info("Model saved in file: %s" % save_path)
        #步长+1
        step += 1
    #做异常处理---用户中断执行
    except KeyboardInterrupt:
      #打印信息
      logging.info('Interrupted')
      #主线程已完成任务,请求关闭处理入队操作的线程
      coord.request_stop()
    #做异常处理---常规错误
    except Exception as e:
      #主线程已完成任务,请求关闭处理入队操作的线程
      coord.request_stop(e)
    finally:
      #保存训练模型
      save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
      #打印信息
      logging.info("Model saved in file: %s" % save_path)
      #主线程已完成任务,请求关闭处理入队操作的线程
      coord.request_stop()
      # 主线程等待所有线程关闭完毕再进入下一步
      coord.join(threads)

def main(unused_argv):
  #调用train()函数
  train()

if __name__ == '__main__':
  #设置日志打印信息,类似于print
  logging.basicConfig(level=logging.INFO)
  # 解析命令行参数
  tf.app.run()

3.export_graph.py(导出模型)

""" Freeze variables and convert 2 generator networks to 2 GraphDef files.
This makes file size smaller and can be used for inference in production.
An example of command-line usage is:
python export_graph.py --checkpoint_dir checkpoints/20170424-1152 \
                       --XtoY_model apple2orange.pb \
                       --YtoX_model orange2apple.pb \
                       --image_size 256
"""

import tensorflow as tf
import os
from tensorflow.python.tools.freeze_graph import freeze_graph
from model import CycleGAN
import utils

#定义全局变量,用FLAGS替代
FLAGS = tf.flags.FLAGS
#定义训练模型路径
tf.flags.DEFINE_string('checkpoint_dir', 'checkpoints/20190510-0446/', 'checkpoints directory path')
#定义模型导出路径:x-y
tf.flags.DEFINE_string('XtoY_model', 'test/real2blur.pb', 'XtoY model name, default: apple2orange.pb')
#定义模型导出路径:y-x
tf.flags.DEFINE_string('YtoX_model', 'test/blur2real.pb', 'YtoX model name, default: orange2apple.pb')
#定义图片尺寸大小
tf.flags.DEFINE_integer('image_size', '1024', 'image size, default: 256')
#定义卷积层得过滤器数量
tf.flags.DEFINE_integer('ngf', 64,'number of gen filters in first conv layer, default: 64')
#定义规范化:instance(实例)
tf.flags.DEFINE_string('norm', 'instance','[instance, batch] use instance norm or batch norm, default: instance')

def export_graph(model_name, XtoY=True):
  #定义计算图
  graph = tf.Graph()
 #接下来的操作都在这个默认的图中
  with graph.as_default():
    #从model.py中引用类:CycleGAN,并且把参数传递进去
    cycle_gan = CycleGAN(ngf=FLAGS.ngf, norm=FLAGS.norm, image_size=FLAGS.image_size)
    #使用占位符,预定义输入图片
    input_image = tf.placeholder(tf.float32, shape=[FLAGS.image_size, FLAGS.image_size, 3], name='input_image')
    #调用model.py的类CycleGAN中的model()函数
    cycle_gan.model()
    #如果输入的是XtoY,则用G生成器,从x生成y
    if XtoY:
      output_image = cycle_gan.G.sample(tf.expand_dims(input_image, 0))   #tf.expand_dims:增加1维度
    #否则输入的是YtoX,则用F生成器,从y生成x
    else:
      output_image = cycle_gan.F.sample(tf.expand_dims(input_image, 0))   #tf.expand_dims:增加1维度
    #创建一个和output_image一样的节点
    output_image = tf.identity(output_image, name='output_image')
    #保存模型
    restore_saver = tf.train.Saver()
    export_saver = tf.train.Saver()
  #定义会话
  with tf.Session(graph=graph) as sess:
    #初始化变量
    sess.run(tf.global_variables_initializer())
    #获取最新模型路径
    latest_ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    #恢复最新模型
    restore_saver.restore(sess, latest_ckpt)
    #保存指定的节点,并将节点值保存为常数   与saver相似,但这个不必全部保存所有参数,可以选择性保存
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), [output_image.op.name])
    #保存计算图和结构---graph;  即把模型中的图和结构导出来了---.pb文件
    tf.train.write_graph(output_graph_def, 'pretrained', model_name, as_text=False)

def main(unused_argv):
  print('Export XtoY model...')
  #调用export_graph函数导出模型
  export_graph(FLAGS.XtoY_model, XtoY=True)
  print('Export YtoX model...')
  #调用export_graph函数导出模型
  export_graph(FLAGS.YtoX_model, XtoY=False)

if __name__ == '__main__':
  # 解析命令行参数,调用main函数
  tf.app.run()

4.inference(测试)

"""Translate an image to another image
An example of command-line usage is:
python export_graph.py --model pretrained/apple2orange.pb \
                       --input input_sample.jpg \
                       --output output_sample.jpg \
                       --image_size 256
"""

import tensorflow as tf
import os
from model import CycleGAN
import utils

#定义全局变量,用FLAGS代替
FLAGS = tf.flags.FLAGS
# path='img/b/'
# for file in os.listdir(path):
#     print(file)
# tf.flags.DEFINE_string('model', 'pretrained/b2a.pb', 'model path (.pb)')
# tf.flags.DEFINE_string('input', 'img/a1.jpg', 'input image path (.jpg)')
# tf.flags.DEFINE_string('output', 'img/b1.jpg', 'output image path (.jpg)')
# tf.flags.DEFINE_integer('image_size', '100', 'image size, default: 256')

def inference(img,output,img_size,model):
  #定义计算的图
  graph = tf.Graph()
  with graph.as_default():
    #相当于f = tf.gfile.FastGFile(img, 'rb'),换个简单的名字
    with tf.gfile.FastGFile(img, 'rb') as f:
      #读取图片
      image_data = f.read()
      #把图片解码成三维矩阵
      input_image = tf.image.decode_jpeg(image_data, channels=3)
      #resize图片
      input_image = tf.image.resize_images(input_image, size=(img_size, img_size))
      #调用units.py的convert2float函数,把int型的图像数据转换为浮动张量   [0-255]转为[0-1]
      input_image = utils.convert2float(input_image)
      #改变input_image的张量
      input_image.set_shape([img_size, img_size, 3])

    #读取模型
    with tf.gfile.FastGFile(model, 'rb') as model_file:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(model_file.read())
    [output_image] = tf.import_graph_def(graph_def,
                          input_map={'input_image': input_image},
                          return_elements=['output_image:0'],
                          name=output)

  #定义会话
  with tf.Session(graph=graph) as sess:
    #执行output_image这一动作
    generated = output_image.eval()
    #打开output
    with open(output, 'wb') as f:
      #把generated得到的值写入output
      f.write(generated)

def main(unused_argv):
    #定义图片尺寸
    img_size = 1024
    #模型路径
    model = 'pretrained/test/blur2real.pb'
    #用于测试的原图片路径---批量
    path1='img/pic/1024/'
    #测试生成的图片保存路径
    path2 = 'img/img/512/'
    #遍历原图片,单张生成测试图片
    for file in os.listdir(path1):
        #输入文件路径=输入文件夹路径+文件名
        img=path1+file
        #输出的文件路径=输出文件夹路径+文件名
        output=path2+file
        #调用inference函数,生成测试图片
        inference(img,output,img_size,model)

if __name__ == '__main__':
  #解析全局变量,并运行main()函数
  tf.app.run()

猜你喜欢

转载自blog.csdn.net/gm_Ergou/article/details/93765378