NVlabs/noise2noise code (3): walkthrough of the network training code

Copyright notice: when reposting, please credit the source: 邢翔瑞 (Xing Xiangrui)'s technical blog https://blog.csdn.net/weixin_36474809 https://blog.csdn.net/weixin_36474809/article/details/87919252

(To be continued)

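Contents

1. Changing the iteration count

1.1 Where it is defined

Where train_config is initialized

Where EasyDict is defined

1.2 How to change the iteration count

2. Network architecture

2.1 Original architecture and code walkthrough

2.2 How autoencoder is called during training

3. Call chain of the training functions

3.1 From config.py to submit_run

3.2 run_wrapper in submit.py

3.3 Calling train.py

3.4 Printed trace

4. The training function

4.1 Input parameters and types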


1. Changing the iteration count

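The original iteration count is large, so a single run takes a long time and is inconvenient for debugging. We therefore reduce the iteration count while debugging. It can be changed directly in def train(args), which is defined under the `if __name__ == "__main__":` block of config.py.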

1.1 Where it is defined

The values are set in config.py; making them smaller here is enough to shorten the run time.

if __name__ == "__main__":
    def train(args):
        if args:
            n2n = args.noise2noise if 'noise2noise' in args else True
            train_config.noise2noise = n2n
            if 'long_train' in args and args.long_train:
                #train_config.iteration_count = 500000
                train_config.iteration_count = 500
                #train_config.eval_interval = 5000
                train_config.eval_interval = 50
                train_config.ramp_down_perc = 0.5

Where train_config is initialized

train_config = dnnlib.EasyDict(
    iteration_count=300000,
    eval_interval=1000,
    minibatch_size=4,
    run_func_name="train.train",
    learning_rate=0.0003,
    ramp_down_perc=0.3,
    noise=gaussian_noise_config,
#    noise=poisson_noise_config,
    noise2noise=True,
    train_tfrecords='datasets/imagenet_val_raw.tfrecords',
    validation_config=default_validation_config
)

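As shown, train_config is a dnnlib.EasyDict. It behaves like an ordinary dict but also allows attribute access, e.g. train_config.iteration_count instead of train_config['iteration_count'].

Where EasyDict is defined

Both train_config and the validation config (default_validation_config) are built with dnnlib.EasyDict. The class is defined in dnnlib/util.py: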

from typing import Any

class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
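To see what the attribute syntax buys, here is a self-contained copy of the EasyDict class above exercised on a cut-down train_config. Nothing is imported from dnnlib; the sample keys are simply taken from train_config, and the helper show() is a hypothetical stand-in for any function that receives the config as keyword arguments.

```python
from typing import Any


class EasyDict(dict):
    """dict subclass that also allows attribute-style access (copy of the class above)."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            # Re-raise as AttributeError so hasattr()/getattr(obj, name, default) behave normally
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


def show(**kwargs: Any) -> list:
    """Hypothetical consumer: returns the keyword names it received."""
    return sorted(kwargs)


cfg = EasyDict(iteration_count=300000, eval_interval=1000, ramp_down_perc=0.3)

# Attribute writes go straight into the underlying dict ...
cfg.iteration_count = 100
cfg.eval_interval = 10
print(cfg["iteration_count"], cfg.eval_interval)

# ... so the object still works wherever a plain dict is expected, e.g. ** unpacking
print(show(**cfg))
print(hasattr(cfg, "learning_rate"))  # missing key -> AttributeError -> False
```

Because EasyDict is still a dict, it can be printed, unpacked with ** and passed around unchanged, which is why config.py can hand it straight to submit_run.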

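1.2 How to change the iteration count

Set, respectively, the iteration count (iteration_count), the evaluation/reporting interval (eval_interval), and ramp_down_perc, which controls the learning-rate ramp-down near the end of training: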

                train_config.iteration_count = 100
                train_config.eval_interval = 10
                train_config.ramp_down_perc = 0.5
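To make ramp_down_perc concrete, here is a hypothetical schedule (ramped_lr is an illustrative name, not the repository's function): the learning rate stays at learning_rate for the first (1 - ramp_down_perc) of the iterations and then decays smoothly to 0 along a cosine-based curve. The exact curve used in train.py is an assumption here; check the repository before relying on it.

```python
import math


def ramped_lr(i: int, iteration_count: int, ramp_down_perc: float, learning_rate: float) -> float:
    """Hypothetical schedule: constant LR, then a smooth cosine-based decay
    over the last ramp_down_perc fraction of the iterations."""
    ramp_start = iteration_count * (1.0 - ramp_down_perc)
    if i < ramp_start:
        return learning_rate
    # t runs from 0 at ramp_start to 1 at iteration_count
    t = (i - ramp_start) / (iteration_count * ramp_down_perc)
    smooth = (0.5 + math.cos(t * math.pi) / 2.0) ** 2  # 1 at t=0, 0 at t=1
    return learning_rate * smooth


# With the debugging settings above (100 iterations, ramp_down_perc = 0.5)
for i in (0, 49, 50, 75, 99):
    print(i, ramped_lr(i, 100, 0.5, 0.0003))
```

With iteration_count = 100 and ramp_down_perc = 0.5, the rate is flat for iterations 0 to 49, is a quarter of the base rate at iteration 75, and is close to zero by iteration 99.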

With 500 iterations, the final PSNR was 27.34 and the run took 2m28s.

With 100 iterations, the final PSNR was 23.42 and the run took 1m48s.

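2. Network architecture

2.1 Original architecture and code walkthrough

The original network can be viewed as a U-Net: on the way down, 3×3 convolutions are followed by max pooling that halves the resolution; on the way up, upscale2d doubles the resolution and each decoder stage concatenates the encoder feature map of the same size, taken from the skips stack.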

    skips = [x]

    n = x
    n = conv_lr('enc_conv0', n, 48)
    n = conv_lr('enc_conv1', n, 48)
    n = maxpool2d(n)
    skips.append(n)

    n = conv_lr('enc_conv2', n, 48)
    n = maxpool2d(n)
    skips.append(n)

    n = conv_lr('enc_conv3', n, 48)
    n = maxpool2d(n)
    skips.append(n)

    n = conv_lr('enc_conv4', n, 48)
    n = maxpool2d(n)
    skips.append(n)

    n = conv_lr('enc_conv5', n, 48)
    n = maxpool2d(n)
    n = conv_lr('enc_conv6', n, 48)

    #-----------------------------------------------
    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)
    n = conv_lr('dec_conv5', n, 96)
    n = conv_lr('dec_conv5b', n, 96)

    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)
    n = conv_lr('dec_conv4', n, 96)
    n = conv_lr('dec_conv4b', n, 96)

    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)
    n = conv_lr('dec_conv3', n, 96)
    n = conv_lr('dec_conv3b', n, 96)

    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)
    n = conv_lr('dec_conv2', n, 96)
    n = conv_lr('dec_conv2b', n, 96)

    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)
    n = conv_lr('dec_conv1a', n, 64)
    n = conv_lr('dec_conv1b', n, 32)

    n = conv('dec_conv1', n, 3, gain=1.0)

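The structure is as follows; the exact channel counts and feature-map sizes at each stage follow the code above.

This essentially pins down the network architecture defined in autoencoder.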
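The parameter counts in the autoencoder summary printed during training can be checked by hand: each 3×3 convolution has 3·3·C_in·C_out weights plus C_out biases, and a decoder stage's input channels are the upscaled channels plus those of the skip tensor popped from the stack. The pure-Python sketch below walks the same encoder/decoder sequence for a 3×256×256 input (layer names copied from the code above; conv_params and unet_summary are hypothetical helpers) and tracks the skip stack the way skips.append/skips.pop do.

```python
def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    """k*k*c_in*c_out weights plus one bias per output channel."""
    return k * k * c_in * c_out + c_out


def unet_summary(c: int = 3, h: int = 256, w: int = 256) -> list:
    """Trace (name, params, (C, H, W)) through the encoder/decoder listed above."""
    rows = []
    skips = [(c, h, w)]  # skips = [x]

    def conv(name: str, c_out: int) -> None:
        nonlocal c
        rows.append((name, conv_params(c, c_out), (c_out, h, w)))
        c = c_out

    def pool() -> None:  # maxpool2d halves H and W
        nonlocal h, w
        h, w = h // 2, w // 2

    conv("enc_conv0", 48)
    conv("enc_conv1", 48)
    pool()
    skips.append((c, h, w))
    for i in (2, 3, 4):
        conv(f"enc_conv{i}", 48)
        pool()
        skips.append((c, h, w))
    conv("enc_conv5", 48)
    pool()
    conv("enc_conv6", 48)

    decoder = [("dec_conv5", 96, "dec_conv5b", 96), ("dec_conv4", 96, "dec_conv4b", 96),
               ("dec_conv3", 96, "dec_conv3b", 96), ("dec_conv2", 96, "dec_conv2b", 96),
               ("dec_conv1a", 64, "dec_conv1b", 32)]
    for name_a, c_a, name_b, c_b in decoder:
        h, w = h * 2, w * 2                   # upscale2d doubles H and W
        skip_c, skip_h, skip_w = skips.pop()  # last pushed skip comes out first
        assert (skip_h, skip_w) == (h, w)
        c += skip_c                           # tf.concat(..., axis=1) adds channels
        conv(name_a, c_a)
        conv(name_b, c_b)
    conv("dec_conv1", 3)  # final 3-channel output convolution
    return rows


if __name__ == "__main__":
    rows = unet_summary()
    for name, params, shape in rows:
        print(f"{name:<12}{params:>8}  {shape}")
    print("Total", sum(p for _, p, _ in rows))
```

Running it prints a total of 991203, the same as the Total line of the summary. The odd-looking input widths also fall out naturally: dec_conv4 sees 96 + 48 = 144 channels and dec_conv1a sees 96 + 3 = 99, because the first skip is the raw RGB input.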

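2.2 How autoencoder is called during training

Add print statements directly inside autoencoder so that each call is logged. A short training run then produces the following output: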

(n2n) jcx@smart-dsp:~/Desktop/NVlabs_noise2noise$ CUDA_VISIBLE_DEVICES=1 python config.py train --train-tfrecords=datasets/part_bsd300.tfrecords --long-train=false --noise=gaussian
----------train in config.py
----------Iteration count is 100 and eval_interval is 10
{'iteration_count': 100, 'eval_interval': 10, 'minibatch_size': 4, 'run_func_name': 'train.train', 'learning_rate': 0.0003, 'ramp_down_perc': 0.5, 'noise': {'func_name': 'train.AugmentGaussian', 'train_stddev_rng_range': (0.0, 50.0), 'validation_stddev': 25.0}, 'noise2noise': True, 'train_tfrecords': 'datasets/part_bsd300.tfrecords', 'validation_config': {'dataset_dir': 'datasets/kodak'}}
----------submit_run in submit.py
----------submit_config.submit_target in {SubmitTarget.LOCAL},create new dir to run
Creating the run dir: results/00006-autoencoder-n2n
----------_populate_run_dir function in submit.py. Copying files to the run dir
----------run_wrapper function in submit.py
dnnlib: Running train.train() on localhost...
----------train in train.py
----------Setting up dataset source from datasets/part_bsd300.tfrecords
----------net = tflib.Network(**config.net_config) in train.py
-------------autoencoder in network.py

autoencoder                 Params      OutputShape             WeightShape
---                         ---         ---                     ---
x                           -           (?, 3, 256, 256)        -
enc_conv0                   1344        (?, 48, 256, 256)       (3, 3, 3, 48)
enc_conv1                   20784       (?, 48, 256, 256)       (3, 3, 48, 48)
MaxPool                     -           (?, 48, 128, 128)       -
enc_conv2                   20784       (?, 48, 128, 128)       (3, 3, 48, 48)
MaxPool_1                   -           (?, 48, 64, 64)         -
enc_conv3                   20784       (?, 48, 64, 64)         (3, 3, 48, 48)
MaxPool_2                   -           (?, 48, 32, 32)         -
enc_conv4                   20784       (?, 48, 32, 32)         (3, 3, 48, 48)
MaxPool_3                   -           (?, 48, 16, 16)         -
enc_conv5                   20784       (?, 48, 16, 16)         (3, 3, 48, 48)
MaxPool_4                   -           (?, 48, 8, 8)           -
enc_conv6                   20784       (?, 48, 8, 8)           (3, 3, 48, 48)
Upscale2D                   -           (?, 48, 16, 16)         -
dec_conv5                   83040       (?, 96, 16, 16)         (3, 3, 96, 96)
dec_conv5b                  83040       (?, 96, 16, 16)         (3, 3, 96, 96)
Upscale2D_1                 -           (?, 96, 32, 32)         -
dec_conv4                   124512      (?, 96, 32, 32)         (3, 3, 144, 96)
dec_conv4b                  83040       (?, 96, 32, 32)         (3, 3, 96, 96)
Upscale2D_2                 -           (?, 96, 64, 64)         -
dec_conv3                   124512      (?, 96, 64, 64)         (3, 3, 144, 96)
dec_conv3b                  83040       (?, 96, 64, 64)         (3, 3, 96, 96)
Upscale2D_3                 -           (?, 96, 128, 128)       -
dec_conv2                   124512      (?, 96, 128, 128)       (3, 3, 144, 96)
dec_conv2b                  83040       (?, 96, 128, 128)       (3, 3, 96, 96)
Upscale2D_4                 -           (?, 96, 256, 256)       -
dec_conv1a                  57088       (?, 64, 256, 256)       (3, 3, 99, 64)
dec_conv1b                  18464       (?, 32, 256, 256)       (3, 3, 64, 32)
dec_conv1                   867         (?, 3, 256, 256)        (3, 3, 32, 3)
---                         ---         ---                     ---
Total                       991203

----------Building TensorFlow graph...
-------------autoencoder in network.py
----------train_step = opt.apply_updates() in train.py
----------Training...
-------------autoencoder in network.py
-------------autoencoder in network.py
-------------autoencoder in network.py
Average PSNR: 10.03
iter 0          time 5s           sec/eval 0.0     sec/iter 0.00    maintenance 5.1
Average PSNR: 20.21
iter 10         time 1m 13s       sec/eval 3.5     sec/iter 0.35    maintenance 64.6
Average PSNR: 21.51
iter 20         time 1m 24s       sec/eval 2.1     sec/iter 0.21    maintenance 8.5
Average PSNR: 22.28
iter 30         time 1m 34s       sec/eval 2.1     sec/iter 0.21    maintenance 8.5
Average PSNR: 22.71
iter 40         time 1m 45s       sec/eval 2.0     sec/iter 0.20    maintenance 8.2
Average PSNR: 22.94
iter 50         time 1m 55s       sec/eval 2.1     sec/iter 0.21    maintenance 8.5
Average PSNR: 23.36
iter 60         time 2m 06s       sec/eval 2.1     sec/iter 0.21    maintenance 8.5
Average PSNR: 23.31
iter 70         time 2m 16s       sec/eval 2.0     sec/iter 0.20    maintenance 8.4
Average PSNR: 23.38
iter 80         time 2m 27s       sec/eval 1.9     sec/iter 0.19    maintenance 8.7
Average PSNR: 23.47
iter 90         time 2m 37s       sec/eval 2.0     sec/iter 0.20    maintenance 8.2
Elapsed time: 2m 47s
dnnlib: Finished train.train() in 2m 48s.
----------try in run wrapper finally handle _finished.txt

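3. Call chain of the training functions

3.1 From config.py to submit_run

config.py contains the main entry point. For training, it fills in the relevant parameters and then calls the train function: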

    submit_config.run_desc = desc + args.desc
    if args.run_dir_root is not None:
        submit_config.run_dir_root = args.run_dir_root
    if args.command is not None:
        args.func(args)
    else:
        # Train if no subcommand was given
        train(args)
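The args.command / args.func(args) pair follows the standard argparse sub-command pattern: each sub-parser registers its handler with set_defaults(func=...), and dest="command" records which sub-command was chosen (None if none was given). A minimal stand-alone sketch; the handlers, the --train-tfrecords default and the validate sub-command are simplified stand-ins, not config.py's actual parser.

```python
import argparse


def train(args):
    # When called without the sub-command, the train-only options are absent
    tfrecords = getattr(args, "train_tfrecords", "datasets/imagenet_val_raw.tfrecords")
    return f"train on {tfrecords}"


def validate(args):
    return "validate"


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="sketch of a config.py-style CLI")
    # dest="command" stores the chosen sub-command name
    subparsers = parser.add_subparsers(dest="command")
    p_train = subparsers.add_parser("train")
    p_train.add_argument("--train-tfrecords", default="datasets/imagenet_val_raw.tfrecords")
    p_train.set_defaults(func=train)  # handler retrieved later as args.func
    p_val = subparsers.add_parser("validate")
    p_val.set_defaults(func=validate)
    return parser


def main(argv=None):
    args = build_parser().parse_args(argv)
    if args.command is not None:
        return args.func(args)
    return train(args)  # no sub-command given: fall back to training


if __name__ == "__main__":
    print(main(["train", "--train-tfrecords=datasets/part_bsd300.tfrecords"]))
```

Note that argparse turns the dashed option --train-tfrecords into the attribute args.train_tfrecords.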

The key statement inside train is:

dnnlib.submission.submit.submit_run(submit_config, **train_config)

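Its inputs, train_config and submit_config, are the parameters set up above. train_config holds the training-related settings, while submit_config (a SubmitConfig defined in submit.py) controls things such as creating the run directory and writing the results.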

#train config
train_config = dnnlib.EasyDict(
    iteration_count=300000,
    eval_interval=1000,
    minibatch_size=4,
    run_func_name="train.train",
    learning_rate=0.0003,
    ramp_down_perc=0.3,
    noise=gaussian_noise_config,
#    noise=poisson_noise_config,
    noise2noise=True,
    train_tfrecords='datasets/imagenet_val_raw.tfrecords',
    validation_config=default_validation_config
)

#submit config
submit_config = dnnlib.SubmitConfig()
submit_config.run_dir_root = 'results'
submit_config.run_dir_ignore += ['datasets', 'results']

desc = "autoencoder"
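Because train_config is a dict, **train_config spreads its entries into keyword arguments. The sketch below uses a hypothetical mini_submit_run whose signature (run_func_name as its own parameter, everything else collected into run_func_kwargs) is an assumption modeled on how run_wrapper later reads submit_config.run_func_name and submit_config.run_func_kwargs; the real dnnlib function also creates the run directory, copies files, and so on.

```python
from types import SimpleNamespace


def mini_submit_run(submit_config, run_func_name: str, **run_func_kwargs):
    """Hypothetical stand-in: record what to run, the way run_wrapper reads it back."""
    submit_config.run_func_name = run_func_name
    submit_config.run_func_kwargs = run_func_kwargs
    return submit_config


train_config = dict(iteration_count=100, eval_interval=10, minibatch_size=4,
                    run_func_name="train.train", learning_rate=0.0003)

# run_func_name is matched by name; every other key lands in run_func_kwargs
cfg = mini_submit_run(SimpleNamespace(run_dir_root="results"), **train_config)
print(cfg.run_func_name)
print(sorted(cfg.run_func_kwargs))
```

Pulling run_func_name out by name is what lets a single flat config carry both "which function to run" and "what to pass it".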

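3.2 run_wrapper in submit.py

submit_run creates the run directory and runs the target function inside it; the core of this is the run_wrapper function.

run_wrapper is what actually runs the wrapped function.

Key statements: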

    try:
        print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
        start_time = time.time()
        util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
        print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
    except:
        if is_local:
            raise
        else:
            traceback.print_exc()

            log_src = os.path.join(submit_config.run_dir, "log.txt")
            log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
            shutil.copyfile(log_src, log_dst)
    finally:
        open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()

In other words, the target function is looked up and called through call_func_by_name:

util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
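The control flow of run_wrapper boils down to: call the function, re-raise on local runs (or print the traceback otherwise), and always create a _finished.txt marker in finally. Below is a simplified, hypothetical sketch of that pattern (run_wrapped is an illustrative name); a temporary directory stands in for submit_config.run_dir, the log-copying branch is omitted, and `except Exception` replaces the original bare `except:`.

```python
import os
import tempfile
import time
import traceback


def run_wrapped(func, run_dir: str, is_local: bool = True, **kwargs):
    """Simplified run_wrapper-style runner (hypothetical)."""
    try:
        start = time.time()
        result = func(**kwargs)
        print(f"Finished {func.__name__}() in {time.time() - start:.2f}s")
        return result
    except Exception:
        if is_local:
            raise              # local runs: let the error propagate
        traceback.print_exc()  # remote runs: log the error and carry on
        return None
    finally:
        # Runs on success and on failure, so a watcher can tell the run has ended
        open(os.path.join(run_dir, "_finished.txt"), "w").close()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as run_dir:
        print(run_wrapped(lambda x: x * 2, run_dir, x=21))
        print(os.path.exists(os.path.join(run_dir, "_finished.txt")))
```

The finally clause is why the last line of the training log reads "try in run wrapper finally handle _finished.txt": the marker is written whether or not train.train raised.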

3.3 Calling train.py

util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)

call_func_by_name looks up the function from its name and calls it:

from typing import Any

def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func_obj = get_obj_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)
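get_obj_by_name is not shown above. One plausible way to resolve a dotted name such as "train.train" is to import the longest importable module prefix and then getattr the remaining parts. The sketch below (resolve_by_name and call_by_name are hypothetical names; dnnlib's actual implementation may differ) uses standard-library targets so it runs anywhere.

```python
import importlib
from typing import Any


def resolve_by_name(name: str) -> Any:
    """Import the longest module prefix of a dotted name, then walk the remaining attributes."""
    parts = name.split(".")
    for i in range(len(parts), 0, -1):
        module_name = ".".join(parts[:i])
        try:
            obj = importlib.import_module(module_name)
        except ImportError:  # also catches ModuleNotFoundError
            continue
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"cannot resolve {name!r}")


def call_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Same shape as call_func_by_name above."""
    assert func_name is not None
    func_obj = resolve_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)


if __name__ == "__main__":
    print(call_by_name(16, func_name="math.sqrt"))
    print(call_by_name([3, 1, 2], func_name="builtins.sorted"))
```

For "train.train", the module train (train.py in the run directory) would be imported and its attribute train returned, which is how the string run_func_name="train.train" becomes a call to the training function.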

3.4 Printed trace

The printed messages confirm the nesting and call order:

----------train in config.py
----------Iteration count is 100 and eval_interval is 10
{'iteration_count': 100, 'eval_interval': 10, 'minibatch_size': 4, 'run_func_name': 'train.train', 'learning_rate': 0.0003, 'ramp_down_perc': 0.5, 'noise': {'func_name': 'train.AugmentGaussian', 'train_stddev_rng_range': (0.0, 50.0), 'validation_stddev': 25.0}, 'noise2noise': True, 'train_tfrecords': 'datasets/part_bsd300.tfrecords', 'validation_config': {'dataset_dir': 'datasets/kodak'}}
----------submit_run in submit.py
----------submit_config.submit_target in {SubmitTarget.LOCAL},create new dir to run
Creating the run dir: results/00503-autoencoder-n2n
----------_populate_run_dir function in submit.py. Copying files to the run dir
----------run_wrapper function in submit.py
dnnlib: Running train.train() on localhost...
---------------train function in train.py

4. The training function

train.py performs the network training.

4.1 Input parameters and types

def train(
    submit_config: dnnlib.SubmitConfig,
    iteration_count: int,
    eval_interval: int,
    minibatch_size: int,
    learning_rate: float,
    ramp_down_perc: float,
    noise: dict,
    validation_config: dict,
    train_tfrecords: str,
    noise2noise: bool):
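The parameter names of train() line up with the keys of train_config, except that run_func_name is used to locate the function rather than passed to it, and submit_config is supplied explicitly by run_wrapper. The sketch below checks that correspondence with inspect.signature, using a stub with the same parameter list (type annotations dropped so it needs no dnnlib).

```python
import inspect


def train(submit_config, iteration_count, eval_interval, minibatch_size, learning_rate,
          ramp_down_perc, noise, validation_config, train_tfrecords, noise2noise):
    """Stub with the same parameter names as train.train (body omitted)."""


train_config_keys = {
    "iteration_count", "eval_interval", "minibatch_size", "run_func_name", "learning_rate",
    "ramp_down_perc", "noise", "noise2noise", "train_tfrecords", "validation_config",
}

params = set(inspect.signature(train).parameters)
# run_func_name selects the function; submit_config is added by run_wrapper
expected = (train_config_keys - {"run_func_name"}) | {"submit_config"}
print(params == expected)
```

A mismatch here (for example a key added to train_config without a matching parameter) would surface as a TypeError when call_func_by_name invokes train.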


Reposted from blog.csdn.net/weixin_36474809/article/details/87919252