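CycleGAN (Part 3): Code Overview

Copyright notice: if you repost this, please credit the source: Xing Xiangrui's (邢翔瑞) technical blog https://blog.csdn.net/weixin_36474809 https://blog.csdn.net/weixin_36474809/article/details/88823295

Goal: get a rough understanding of how the CycleGAN code is structured.

Reference: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/overview.md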

Contents

1. Overview

train.py

test.py

data folder

models folder

options folder

util folder

2. train.py

3. test.py

4. data folder

4.1 __init__.py

4.2 base_dataset.py

4.3 image_folder.py

4.4 template_dataset.py

4.5 aligned_dataset.py

4.6 unaligned_dataset.py

4.7 single and colorization datasets

5. models

6. options and util


1. Overview

train.py

Used for model training. It works with various models (option --model, e.g., pix2pix, cyclegan, colorization) and different datasets (option --dataset_mode, e.g., aligned, unaligned, single, colorization).

test.py

Used for model testing.

data folder

Contains all the code for data loading and data preprocessing.

models folder

Model-related objective functions, optimizations, and network architectures.

options folder

Options for training, testing, and the individual models.
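As a rough sketch of how model-specific options can sit on top of the shared ones: the base options are parsed once just to learn --model, then that model's option setter (the modify_commandline_options hook listed in section 5) adds its own flags before the final parse. The code below imitates that two-stage pattern with plain argparse; the OPTION_SETTERS registry, the two setter functions and their defaults are hypothetical stand-ins, not the repository's code.

```python
import argparse


def cyclegan_options(parser, is_train):
    # Hypothetical model-specific setter, modeled on modify_commandline_options
    if is_train:
        parser.add_argument("--lambda_A", type=float, default=10.0,
                            help="weight of the A -> B -> A cycle loss")
    return parser


def pix2pix_options(parser, is_train):
    # Hypothetical setter for a second model
    if is_train:
        parser.add_argument("--lambda_L1", type=float, default=100.0)
    return parser


# Hypothetical registry standing in for a per-model lookup
OPTION_SETTERS = {"cycle_gan": cyclegan_options, "pix2pix": pix2pix_options}


def gather_options(argv, is_train=True):
    parser = argparse.ArgumentParser()
    # Base options shared by training and testing
    parser.add_argument("--dataroot", required=True)
    parser.add_argument("--model", default="cycle_gan")
    parser.add_argument("--dataset_mode", default="unaligned")
    # Stage 1: parse only what is known so far, to find out which model is used
    opt, _ = parser.parse_known_args(argv)
    # Stage 2: let that model add its own flags, then parse everything strictly
    parser = OPTION_SETTERS[opt.model](parser, is_train)
    return parser.parse_args(argv)


opt = gather_options(["--dataroot", "./datasets/horse2zebra", "--lambda_A", "5"])
print(opt.model, opt.lambda_A)  # cycle_gan 5.0
```

Parsing in two stages is what lets each model declare flags that only exist when that model is selected.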

util folder

A miscellaneous collection of helper functions.

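2. train.py

train.py loads the data (data folder), creates the model (models folder), loops over the epochs, updates the parameters within each epoch, and saves the networks.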
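A stripped-down sketch of that loop, assuming a model object that exposes set_input, optimize_parameters and save_networks (method names taken from the repository's model interface; ToyModel itself is a hypothetical stand-in that only counts calls). The option names n_epochs and save_epoch_freq are modeled on the repo's training options, which may differ between versions; learning-rate decay, logging and visualization are left out.

```python
from types import SimpleNamespace


class ToyModel:
    """Hypothetical stand-in for the object the real create_model(opt) returns."""

    def __init__(self):
        self.steps = 0
        self.saved = []

    def set_input(self, data):
        self.data = data           # unpack one batch from the dataset

    def optimize_parameters(self):
        self.steps += 1            # real model: forward, losses, backward, step

    def save_networks(self, epoch):
        self.saved.append(epoch)   # real model: write checkpoints to disk


def train(opt, dataset, model):
    # Outer loop over epochs, inner loop over batches
    for epoch in range(1, opt.n_epochs + 1):
        for data in dataset:
            model.set_input(data)
            model.optimize_parameters()
        if epoch % opt.save_epoch_freq == 0:
            model.save_networks(epoch)


opt = SimpleNamespace(n_epochs=4, save_epoch_freq=2)
dataset = [{"A": i, "B": -i} for i in range(3)]  # three toy batches
model = ToyModel()
train(opt, dataset, model)
print(model.steps, model.saved)  # 12 [2, 4]
```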

3. test.py

test.py loads the data, creates the model, and then feeds the data through the model for testing.
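The test side can be sketched the same way. ToyModel is again a hypothetical stand-in; eval, test and get_current_visuals mirror methods that, as far as I know, the repository's model classes provide, and num_test imitates the option that caps how many test samples are processed. Saving the result images to disk is omitted.

```python
from types import SimpleNamespace


class ToyModel:
    """Hypothetical stand-in for a model used at test time."""

    def __init__(self):
        self.training = True

    def eval(self):
        self.training = False      # e.g. freeze dropout / batch-norm behavior

    def set_input(self, data):
        self.real_A = data["A"]

    def test(self):
        # Forward pass only: no gradients, no weight updates
        self.fake_B = self.real_A * 2

    def get_current_visuals(self):
        return {"real_A": self.real_A, "fake_B": self.fake_B}


def run_test(opt, dataset, model):
    if opt.eval:
        model.eval()
    results = []
    for i, data in enumerate(dataset):
        if i >= opt.num_test:      # only process the first num_test samples
            break
        model.set_input(data)
        model.test()
        results.append(model.get_current_visuals())
    return results


opt = SimpleNamespace(eval=True, num_test=2)
model = ToyModel()
out = run_test(opt, [{"A": 1}, {"A": 2}, {"A": 3}], model)
print(out)  # [{'real_A': 1, 'fake_B': 2}, {'real_A': 2, 'fake_B': 4}]
```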

4. data folder

Contains the code for loading and preprocessing data.

4.1 __init__.py

Creates the datasets used in training and testing. Usage:

from data import create_dataset
dataset = create_dataset(opt)  # create a dataset from the option object opt
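How does create_dataset know which class to build? A sketch of the lookup, assuming the naming convention used in the repo: --dataset_mode unaligned maps to a class called UnalignedDataset (underscores dropped, "dataset" appended, compared case-insensitively). The real code imports the module data/<mode>_dataset.py with importlib and wraps the result in a PyTorch DataLoader; here the candidate classes are simply passed in as a list, an assumption made to keep the example self-contained.

```python
from types import SimpleNamespace


class BaseDataset:
    """Stand-in for the real base class in data/base_dataset.py."""

    def __init__(self, opt):
        self.opt = opt


class AlignedDataset(BaseDataset):
    pass


class UnalignedDataset(BaseDataset):
    pass


class SingleDataset(BaseDataset):
    pass


def find_dataset_using_name(dataset_name, candidates):
    # "unaligned" -> "unaligneddataset", matched case-insensitively to class names
    target = dataset_name.replace("_", "") + "dataset"
    for cls in candidates:
        if cls.__name__.lower() == target.lower() and issubclass(cls, BaseDataset):
            return cls
    raise NotImplementedError(f"no dataset class named like '{target}'")


def create_dataset(opt, candidates):
    dataset_class = find_dataset_using_name(opt.dataset_mode, candidates)
    return dataset_class(opt)


CANDIDATES = [AlignedDataset, UnalignedDataset, SingleDataset]
dataset = create_dataset(SimpleNamespace(dataset_mode="unaligned"), CANDIDATES)
print(type(dataset).__name__)  # UnalignedDataset
```

The same "file name plus class name" convention is what section 5 describes for models (dummy_model.py defining DummyModel).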

4.2 base_dataset.py

https://docs.python.org/3/library/abc.html

  • base_dataset.py implements an abstract base class (ABC) for datasets. It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.

In other words, the dataset base class is built on Python's abc (abstract base class) module.
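A minimal illustration of what the abc module buys here: a base class whose __len__ and __getitem__ are abstract cannot be instantiated, so every concrete dataset is forced to implement them. This sketch drops the torch.utils.data.Dataset parent that the real class also has, and ListDataset is a hypothetical subclass. The scale_width_size helper only computes the output size in the spirit of __scale_width (target width, aspect ratio kept, height no smaller than the crop size); treat it as an approximation rather than the repo's exact function, which also performs the actual resize.

```python
from abc import ABC, abstractmethod


class BaseDataset(ABC):
    """Simplified sketch; the real class also inherits torch.utils.data.Dataset."""

    def __init__(self, opt):
        self.opt = opt

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser  # subclasses may add dataset-specific flags

    @abstractmethod
    def __len__(self):
        """Return the total number of samples."""

    @abstractmethod
    def __getitem__(self, index):
        """Return one sample (a dict) for the given index."""


class ListDataset(BaseDataset):
    # Hypothetical concrete subclass: it must implement both abstract methods
    def __init__(self, opt, items):
        super().__init__(opt)
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return {"A": self.items[index]}


def scale_width_size(ow, oh, target_width, crop_size):
    # Keep the aspect ratio, but never let the height fall below crop_size
    if ow == target_width and oh >= crop_size:
        return ow, oh
    h = int(max(target_width * oh / ow, crop_size))
    return target_width, h


try:
    BaseDataset(opt=None)
except TypeError as err:
    print("cannot instantiate:", err)
ds = ListDataset(opt=None, items=["x.png", "y.png"])
print(len(ds), ds[1])                        # 2 {'A': 'y.png'}
print(scale_width_size(400, 600, 286, 256))  # (286, 429)
print(scale_width_size(512, 256, 286, 256))  # (286, 256)
```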

4.3 image_folder.py

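The authors modified the official PyTorch image-folder code so that images are read not only from the folder itself but also from its subfolders.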
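A sketch of such a recursive scan, using os.walk so that files in subdirectories are found as well. The extension list, the case-insensitive check and the max_dataset_size cap are assumptions chosen to resemble the repo's make_dataset helper, not a copy of it.

```python
import os
import tempfile

IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".tif", ".tiff")


def is_image_file(filename):
    return filename.lower().endswith(IMG_EXTENSIONS)


def make_dataset(root, max_dataset_size=float("inf")):
    # Walk root and all of its subdirectories, collecting image paths
    assert os.path.isdir(root), f"{root} is not a valid directory"
    images = []
    for dirpath, _, fnames in sorted(os.walk(root)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(dirpath, fname))
    return images[:min(max_dataset_size, len(images))]


# Usage: one image at the top level, one in a subfolder, one non-image file
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "sub"))
    for rel in ("a.png", os.path.join("sub", "b.JPG"), os.path.join("sub", "notes.txt")):
        open(os.path.join(tmp, rel), "w").close()
    found = [os.path.relpath(p, tmp) for p in make_dataset(tmp)]
    capped = make_dataset(tmp, max_dataset_size=1)
print(found)
```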

4.4 template_dataset.py

Provides a template for creating your own dataset, with detailed documentation.

4.5 aligned_dataset.py

Loads paired samples (mainly used for pix2pix; it is of little use for our CycleGAN).
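Aligned data stores A and B side by side in one {A,B} image, so loading a pair amounts to cutting that image down the middle. The sketch uses a toy image (a list of pixel rows) instead of a PIL image; the comments show the equivalent PIL crop boxes, assuming a combined image of width w and height h.

```python
def split_ab(ab):
    """Split a side-by-side {A,B} image into left half A and right half B.

    `ab` is a toy image given as a list of pixel rows (a stand-in for a PIL
    image of width w and height h).
    """
    w = len(ab[0])
    w2 = w // 2                     # same as int(w / 2) for a positive width
    a = [row[:w2] for row in ab]    # like AB.crop((0, 0, w2, h))
    b = [row[w2:] for row in ab]    # like AB.crop((w2, 0, w, h))
    return a, b


ab = [[1, 2, 9, 8],
      [3, 4, 7, 6]]
A, B = split_ab(ab)
print(A)  # [[1, 2], [3, 4]]
print(B)  # [[9, 8], [7, 6]]
```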

4.6 unaligned_dataset.py

For unpaired datasets, as used by CycleGAN. For training, put the domain A images in trainA and the domain B images in trainB; the same applies at test time (testA and testB).
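Because the two domains are unpaired and may have different sizes, each A image needs some B partner. A sketch of the indexing, assuming the behavior I understand the repo's unaligned dataset to have: A is walked in order (index modulo its size), B is drawn at random unless serial_batches is set, and the dataset length is the larger of the two folders. Only file paths are handled; image loading and transforms are omitted, and UnalignedIndexer is a hypothetical name.

```python
import random


class UnalignedIndexer:
    """Index logic of an unpaired A/B dataset (toy version, file paths only)."""

    def __init__(self, paths_A, paths_B, serial_batches=False, seed=None):
        self.A_paths = sorted(paths_A)
        self.B_paths = sorted(paths_B)
        self.serial_batches = serial_batches
        self.rng = random.Random(seed)

    def __len__(self):
        # One pass covers the larger of the two domains
        return max(len(self.A_paths), len(self.B_paths))

    def __getitem__(self, index):
        A_path = self.A_paths[index % len(self.A_paths)]
        if self.serial_batches:
            index_B = index % len(self.B_paths)                   # fixed pairing
        else:
            index_B = self.rng.randint(0, len(self.B_paths) - 1)  # random partner
        return {"A_paths": A_path, "B_paths": self.B_paths[index_B]}


ds = UnalignedIndexer(["a1.jpg", "a2.jpg", "a3.jpg"], ["b1.jpg", "b2.jpg"],
                      serial_batches=True)
print(len(ds))  # 3
print(ds[2])    # {'A_paths': 'a3.jpg', 'B_paths': 'b1.jpg'}
```

Random B partners keep the model from ever seeing the same fixed A/B pairing, which fits the unpaired setting.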

4.7 single and colorization datasets

  • image_folder.py implements an image folder class. We modify the official PyTorch image folder code so that this class can load images from both the current directory and its subdirectories.
  • template_dataset.py provides a dataset template with detailed documentation. Check out this file if you plan to implement your own dataset.
  • aligned_dataset.py includes a dataset class that can load image pairs. It assumes a single image directory /path/to/data/train, which contains image pairs in the form of {A,B}. See here on how to prepare aligned datasets. During test time, you need to prepare a directory /path/to/data/test as test data.
  • unaligned_dataset.py includes a dataset class that can load unaligned/unpaired datasets. It assumes two directories that host training images from domain A (/path/to/data/trainA) and from domain B (/path/to/data/trainB), respectively. You can then train the model with the dataset flag --dataroot /path/to/data. Similarly, you need to prepare two directories /path/to/data/testA and /path/to/data/testB for test time.
  • single_dataset.py includes a dataset class that can load a set of single images specified by the path --dataroot /path/to/data. It can be used to generate CycleGAN results for one side only, with the model option --model test.
  • colorization_dataset.py implements a dataset class that can load a set of natural images in RGB and convert them from RGB into (L, ab) pairs in Lab color space. It is required by the pix2pix-based colorization model (--model colorization).
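The (L, ab) split can be illustrated per pixel with the standard sRGB to CIELAB conversion under a D65 white point: L is lightness and a, b are the two color-opponent channels, so a colorization model receives L as input and predicts ab. The formula is the textbook one, written out by hand instead of calling whatever library routine the repository uses; srgb_to_lab and to_l_ab are hypothetical helper names, and any scaling of the channels into a network-friendly range is left out.

```python
def srgb_to_lab(rgb):
    """Convert one 8-bit sRGB pixel (r, g, b) to CIELAB (L*, a*, b*), D65 white."""
    def linearize(c):
        c = c / 255.0  # scale to 0..1, then undo the sRGB gamma curve
        return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4

    def f(t):
        delta = 6 / 29
        return t ** (1 / 3) if t > delta ** 3 else t / (3 * delta ** 2) + 4 / 29

    r, g, b = (linearize(c) for c in rgb)
    # Linear sRGB -> CIE XYZ
    x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
    y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
    z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b
    # Normalize by the D65 reference white, then apply the Lab nonlinearity
    fx, fy, fz = f(x / 0.95047), f(y / 1.00000), f(z / 1.08883)
    L = 116 * fy - 16
    a_star = 500 * (fx - fy)
    b_star = 200 * (fy - fz)
    return L, a_star, b_star


def to_l_ab(rgb):
    """Split one RGB pixel into the model input L and the target (a, b)."""
    L, a_star, b_star = srgb_to_lab(rgb)
    return L, (a_star, b_star)


L, ab = to_l_ab((255, 0, 0))
print(round(L, 2), [round(v, 2) for v in ab])  # pure red is roughly L=53, a=80, b=67
```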

5. models

Modules related to objective functions, optimizations, and network architectures.

Procedure for adding or changing a model:

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>:                     unpack data from dataset and apply preprocessing.
    -- <forward>:                       produce intermediate results.
    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list):          specify the training losses that you want to plot and save.
    -- self.model_names (str list):         define networks used in our training.
    -- self.visual_names (str list):        specify the images that you want to display and save.
    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
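Put together, a toy subclass following these rules might look like the sketch below. BaseModel here is a simplified hypothetical stand-in (the real one manages torch networks, devices and checkpoints), and Param, ToyNet and ToySGD replace real tensors, networks and torch.optim so the example runs without PyTorch. What it does show is the required shape: the four lists set in __init__, the five functions, and one optimizer that updates two generators at once through itertools.chain, as cycle_gan_model.py does for G_A and G_B.

```python
import argparse
import itertools
from abc import ABC, abstractmethod


class Param:
    """Toy scalar parameter (hypothetical stand-in for a torch tensor)."""
    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class ToyNet:
    """Hypothetical one-weight 'network' computing y = w * x."""
    def __init__(self, w):
        self.w = Param(w)

    def parameters(self):
        return [self.w]

    def __call__(self, x):
        return self.w.value * x


class ToySGD:
    """Plain gradient descent over any iterable of Params (not torch.optim)."""
    def __init__(self, params, lr):
        self.params = list(params)  # consume the chained iterator once
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad


class BaseModel(ABC):
    """Simplified stand-in for the repo's BaseModel."""
    def __init__(self, opt):
        self.opt = opt
        self.loss_names, self.model_names = [], []
        self.visual_names, self.optimizers = [], []

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser  # optional hook: add or override model options

    @abstractmethod
    def set_input(self, data): ...

    @abstractmethod
    def forward(self): ...

    @abstractmethod
    def optimize_parameters(self): ...


class DummyModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)            # call the base initializer first
        self.loss_names = ["cycle"]              # losses to plot and save
        self.model_names = ["G_A", "G_B"]        # networks used in training
        self.visual_names = ["real_A", "rec_A"]  # images to display and save
        self.netG_A = ToyNet(2.0)
        self.netG_B = ToyNet(0.25)
        # One optimizer for both generators, grouped with itertools.chain
        self.optimizer_G = ToySGD(
            itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
            lr=opt.lr)
        self.optimizers = [self.optimizer_G]

    def set_input(self, data):
        self.real_A = data["A"]  # unpack one sample from the dataset

    def forward(self):
        self.fake_B = self.netG_A(self.real_A)  # A -> B
        self.rec_A = self.netG_B(self.fake_B)   # B -> A (reconstruction)

    def optimize_parameters(self):
        self.forward()
        # Cycle loss (rec_A - real_A)^2 with rec_A = b * a * x; gradients by hand
        a, b, x = self.netG_A.w.value, self.netG_B.w.value, self.real_A
        diff = self.rec_A - x
        self.loss_cycle = diff ** 2
        self.optimizer_G.zero_grad()
        self.netG_A.w.grad = 2 * diff * b * x
        self.netG_B.w.grad = 2 * diff * a * x
        self.optimizer_G.step()


model = DummyModel(argparse.Namespace(lr=0.1))
losses = []
for _ in range(20):
    model.set_input({"A": 1.0})
    model.optimize_parameters()
    losses.append(model.loss_cycle)
print(losses[0], losses[-1])  # the cycle loss shrinks toward 0
```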

6. options and util
