Style Transfer 1-04: Liquid Warping GAN (Impersonator) - No-Blind-Spots Source Code Walkthrough (1) - Training Code Overview

The following is part of my series of notes on Liquid Warping GAN (Impersonator) pose transfer. If you find any mistakes, please point them out and I will correct them right away. Anyone interested is welcome to add me on WeChat (a944284742) to discuss the technology. And if this helped you, please remember to leave a like, since that is the biggest encouragement for me.
Style Transfer 1-00: Liquid Warping GAN (Impersonator) - Table of Contents - The Latest No-Blind-Spots Walkthrough

Training script parameters

Before training, let's take a look at the training script, scripts/train_iPER.sh. Here it is with some brief annotations:

#! /bin/bash

# basic configs
#gpu_ids=0,1     # if using multi-gpus, increasing the batch_size
gpu_ids=0

# dataset configs
dataset_model=iPER  # use the iPER dataset format
data_dir=../2.Dataset/ # root directory of the dataset, need to be replaced!!!!!
images_folder=images # folder containing the video frames
smpls_folder=smpls # folder containing the SMPL parameters (body pose and shape)
train_ids_file=train.txt # ids of the videos used for training
test_ids_file=val.txt # ids of the videos used for validation

# saving configs: where the checkpoints are stored and under what name
checkpoints_dir=../2.Dataset/ckpts_models   # directory to save models, need to be replaced!!!!!
name=exp_iPER   # the directory is ${checkpoints_dir}/name, which is used to save the checkpoints.

# model configs: which trainer to use; the default is the impersonation (motion imitation) trainer
model=impersonator_trainer
gen_name=impersonator
image_size=256

# training configs: the loss weights and schedule from the paper
load_path="None"
batch_size=1
lambda_rec=10.0
lambda_tsf=10.0
lambda_face=5.0
lambda_style=0.0
lambda_mask=1.0
#lambda_mask=2.5
lambda_mask_smooth=1.0
# keep the learning rate constant while the epoch is in [0, 5]
nepochs_no_decay=5
# then decay the learning rate every epoch while the epoch is in [6, 30] (25 + 5)
nepochs_decay=25
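# How these weights presumably combine into the total generator objective;
# this is an illustrative formula only, the exact loss terms live in
# models/impersonator_trainer.py:
#   L_G = lambda_rec * L_rec + lambda_tsf * L_tsf + lambda_face * L_face
#       + lambda_style * L_style + lambda_mask * L_mask
#       + lambda_mask_smooth * L_mask_smooth + L_adv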




python train.py --gpu_ids ${gpu_ids}        \
    --data_dir  ${data_dir}                 \
    --images_folder    ${images_folder}     \
    --smpls_folder     ${smpls_folder}      \
    --checkpoints_dir  ${checkpoints_dir}   \
    --train_ids_file   ${train_ids_file}    \
    --test_ids_file    ${test_ids_file}     \
    --load_path        ${load_path}         \
    --model            ${model}             \
    --gen_name         ${gen_name}          \
    --name             ${name}              \
    --dataset_mode     ${dataset_model}     \
    --image_size       ${image_size}        \
    --batch_size       ${batch_size}        \
    --lambda_face      ${lambda_face}       \
    --lambda_tsf       ${lambda_tsf}        \
    --lambda_style     ${lambda_style}      \
    --lambda_rec       ${lambda_rec}         \
    --lambda_mask      ${lambda_mask}       \
    --lambda_mask_smooth  ${lambda_mask_smooth} \
    --nepochs_no_decay ${nepochs_no_decay}  --nepochs_decay ${nepochs_decay}  \
    --mask_bce     --use_vgg       --use_face
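Going by the variables above, data_dir is expected to contain an iPER-style layout roughly like the one below (the folder and file names come straight from the script; the tree itself is an assumption):

../2.Dataset/
├── images/      # the video frames, one sub-directory per video
├── smpls/       # the SMPL pose/shape parameters for each video
├── train.txt    # ids of the videos used for training
└── val.txt      # ids of the videos used for validation

Once data_dir and checkpoints_dir point at real paths, the script can be launched, e.g. from the repository root:

bash scripts/train_iPER.sh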

Annotated train.py

Below is the overall annotation of the training code. Many details are skipped, such as what exactly gets visualized and what the information printed to the terminal actually is, but that is fine: we will certainly find those answers while analyzing the code later on. The annotated code follows:

import time
from options.train_options import TrainOptions
from data.custom_dataset_data_loader import CustomDatasetDataLoader
from models.models import ModelsFactory
from utils.tb_visualizer import TBVisualizer
from collections import OrderedDict


class Train(object):
    def __init__(self):
        # parse the command-line arguments
        self._opt = TrainOptions().parse()
        # build the training and test data loaders
        data_loader_train = CustomDatasetDataLoader(self._opt, is_for_train=True)
        data_loader_test = CustomDatasetDataLoader(self._opt, is_for_train=False)

        # wrap them as iterable datasets
        self._dataset_train = data_loader_train.load_data()
        self._dataset_test = data_loader_test.load_data()

        # sizes of the training and test sets, i.e. how many video clips each contains
        self._dataset_train_size = len(data_loader_train)
        self._dataset_test_size = len(data_loader_test)
        print('#train video clips = %d' % self._dataset_train_size)
        print('#test video clips = %d' % self._dataset_test_size)

        # build the model from its name; by default self._opt.model = impersonator_trainer
        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        # TensorBoard visualizer
        self._tb_visualizer = TBVisualizer(self._opt)

        # start training
        self._train()

    def _train(self):
        # total number of steps already taken (nonzero when resuming from load_epoch)
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        # number of iterations per epoch (dataset size / batch size)
        self._iters_per_epoch = self._dataset_train_size / self._opt.batch_size

        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        # main loop over the epochs
        for i_epoch in range(self._opt.load_epoch + 1, self._opt.nepochs_no_decay + self._opt.nepochs_decay + 1):
            epoch_start_time = time.time()
            # train epoch
            self._train_epoch(i_epoch)

            # save model: checkpoint at the end of every epoch
            print('saving the model at the end of epoch %d, iters %d' % (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            # print epoch info: which epoch just finished and how long it took
            time_epoch = time.time() - epoch_start_time
            print('End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                  (i_epoch, self._opt.nepochs_no_decay + self._opt.nepochs_decay, time_epoch,
                   time_epoch / 60, time_epoch / 3600))

            # update learning rate: start decaying once the constant-lr phase ends
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()


    def _train_epoch(self, i_epoch):
        """
        训练一个epoch的细节
        :param i_epoch:
        :return:
        """
        epoch_iter = 0
        # switch the model to training mode
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):
            # record the start time of this iteration
            iter_start_time = time.time()

            # display flags: printing and visualization are throttled by wall-clock
            # time (print_freq_s and display_freq_s are intervals in seconds)
            do_visuals = self._last_display_time is None or time.time() - self._last_display_time > self._opt.display_freq_s
            do_print_terminal = time.time() - self._last_print_time > self._opt.print_freq_s or do_visuals

            # train model: feed the current batch to the model
            self._model.set_input(train_batch)
            # the generator is only optimized every train_G_every_n_iterations
            # iterations, or whenever visuals are collected
            trainable = ((i_train_batch+1) % self._opt.train_G_every_n_iterations == 0) or do_visuals
            # optimize the parameters, i.e. run back-propagation; keep_data_for_visuals
            # tells the model whether to keep this batch's outputs for visualization
            self._model.optimize_parameters(keep_data_for_visuals=do_visuals, trainable=trainable)

            # update epoch info: advance the step counters
            self._total_steps += self._opt.batch_size
            epoch_iter += self._opt.batch_size

            # display terminal: print the current training losses
            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch, do_visuals)
                self._last_print_time = time.time()

            # display visualizer: push images and scalars to TensorBoard
            if do_visuals:
                self._display_visualizer_train(self._total_steps)
                self._display_visualizer_val(i_epoch, self._total_steps)
                self._last_display_time = time.time()

            # save model: also save the latest checkpoint every save_latest_freq_s seconds
            if self._last_save_latest_time is None or time.time() - self._last_save_latest_time > self._opt.save_latest_freq_s:
                print('saving the latest model (epoch %d, total_steps %d)' % (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch, visuals_flag):
        """
        终端打印训练想过信息,并且进行可视化
        """
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch, self._iters_per_epoch, errors, t, visuals_flag)

    def _display_visualizer_train(self, total_steps):
        """
        训练可视化
        """
        self._tb_visualizer.display_current_results(self._model.get_current_visuals(), total_steps, is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_errors(), total_steps, is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_scalars(), total_steps, is_train=True)

    def _display_visualizer_val(self, i_epoch, total_steps):
        """
        评估可视化
        """
        val_start_time = time.time()

        # set the model to evaluation mode
        self._model.set_eval()

        # evaluate at most self._opt.num_iters_validate validation batches
        val_errors = OrderedDict()

        # iterate over the validation data
        for i_val_batch, val_batch in enumerate(self._dataset_test):
            if i_val_batch == self._opt.num_iters_validate:
                break

            # evaluate model
            self._model.set_input(val_batch)
            self._model.forward(keep_data_for_visuals=(i_val_batch == 0))
            errors = self._model.get_current_errors()

            # store current batch errors
            for k, v in errors.items():
                if k in val_errors:
                    val_errors[k] += v
                else:
                    val_errors[k] = v

        # normalize errors
        for k in val_errors:
            val_errors[k] /= self._opt.num_iters_validate

        # visualize
        t = (time.time() - val_start_time)
        self._tb_visualizer.print_current_validate_errors(i_epoch, val_errors, t)
        self._tb_visualizer.plot_scalars(val_errors, total_steps, is_train=False)
        self._tb_visualizer.display_current_results(self._model.get_current_visuals(), total_steps, is_train=False)

        # set model back to train
        self._model.set_train()


if __name__ == "__main__":
    Train()
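Since TBVisualizer writes TensorBoard event files, training can presumably be monitored with something like the command below (the log directory here is an assumption based on the saving configs above; check TBVisualizer for the actual path it writes to):

tensorboard --logdir ../2.Dataset/ckpts_models/exp_iPER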

It is really quite simple and follows the usual pattern (a minimal sketch of it follows the list):
1. Load the training and test dataset iterators.
2. Build the network model.
3. Train iteratively.
4. Evaluate and save the model.
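Stripped of the project specifics, the loop above boils down to the skeleton below. This is a minimal sketch with illustrative names; it mirrors the method names used in train.py but is not the repo's actual code:

def fit(model, train_loader, nepochs_no_decay, nepochs_decay):
    """Minimal sketch of the training pattern used by train.py."""
    for epoch in range(1, nepochs_no_decay + nepochs_decay + 1):
        model.set_train()
        for batch in train_loader:        # iterate the training data
            model.set_input(batch)        # feed one batch
            model.optimize_parameters()   # forward + backward + optimizer step
        model.save(epoch)                 # checkpoint at the end of each epoch
        if epoch > nepochs_no_decay:      # keep the lr constant at first,
            model.update_learning_rate()  # then decay it every epoch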
That is it for the overall structure; the next section starts explaining every detail of the code.
