Introduction to mmdetection: train.py walkthrough

4. train.py walkthrough

Just as with testing, there are single-GPU and multi-GPU modes for training; in fact, the models we tested in the previous section are themselves products of this training process.

Single-GPU training

python tools/train.py ${CONFIG_FILE}

To specify the working directory from the command line, add the argument --work_dir ${YOUR_WORK_DIR}. If it is not given, the default set in the config file under config/**.py is used, e.g. work_dir = './work_dirs/faster_rcnn_r50_fpn_1x_voc0712'.
Argument descriptions:

  • --validate (strongly recommended): perform evaluation every k epochs during training (the default is 1; the interval can be changed in the config).
  • --work_dir ${WORK_DIR}: override the working directory specified in the config file.
  • --resume_from ${CHECKPOINT_FILE}: resume training from a previous checkpoint file.
  • --gpus: the number of GPUs to use; the default is 1.
  • --launcher: the job launcher for distributed training; the default none means no distributed training. A combined example follows this list.
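
For instance, a single-GPU run that writes to a custom directory and resumes from an earlier checkpoint might look like this (the config and checkpoint paths are only illustrative):

python tools/train.py configs/faster_rcnn_r50_fpn_1x_voc0712.py \
    --work_dir ./work_dirs/my_experiment \
    --resume_from ./work_dirs/my_experiment/latest.pth

Note that --validate is omitted on purpose; as explained below, it only takes effect with distributed training.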

A few points worth noting:

--validate only takes effect with multi-GPU (distributed) training, not single-GPU training; even workflow = [('train', 1), ('val', 1)] (train one epoch, then validate one epoch), which we will meet later, does not apply to the single-GPU case.

The difference between resume_from and load_from: resume_from loads both the model weights and the optimizer state, and the epoch counter also continues from the given checkpoint; it is typically used to resume a training run that was interrupted unexpectedly. load_from loads only the model weights, and training starts from epoch 0; it is typically used for fine-tuning.
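
Both options live in the config file; a minimal sketch (the checkpoint path is hypothetical), where in practice at most one of the two is set:

# in the config file, e.g. configs/faster_rcnn_r50_fpn_1x_voc0712.py
resume_from = None  # resume an interrupted run: weights + optimizer state + epoch
load_from = './checkpoints/faster_rcnn_r50_fpn_1x.pth'  # fine-tune: weights only, epoch restarts at 0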

Multi-GPU training

./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]

Let's take a look at the contents of dist_train.sh:

#!/usr/bin/env bash

PYTHON=${PYTHON:-"python"}

CONFIG=$1
GPUS=$2

$PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}

In the end this still calls train.py; the script just launches it through torch.distributed.launch with one process per GPU and passes --launcher pytorch to enable distributed training, forwarding any extra arguments (${@:3}).
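
For example, training on 8 GPUs with evaluation enabled would look like this (the config path is only illustrative):

./tools/dist_train.sh configs/faster_rcnn_r50_fpn_1x_voc0712.py 8 --validate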

One thing worth knowing here is how the learning rate is chosen:
The default learning rate in the config files assumes 8 GPUs and 2 img/gpu (batch_size = 8 * 2 = 16). According to the linear scaling rule, you need to set the learning rate proportionally to the total batch size when using a different number of GPUs or a different img/gpu, e.g. lr = 0.01 for 4 GPUs * 2 img/gpu (batch size 8) and lr = 0.08 for 16 GPUs * 4 img/gpu (batch size 64).
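
The rule is just a proportion; here is a minimal sketch of the arithmetic, assuming the Faster R-CNN default of lr = 0.02 at the base batch size of 16:

# linear scaling rule: lr is proportional to the total batch size
base_lr = 0.02      # assumed default for 8 GPUs * 2 img/gpu
base_batch = 8 * 2  # = 16

def scaled_lr(num_gpus, imgs_per_gpu):
    return base_lr * (num_gpus * imgs_per_gpu) / base_batch

print(scaled_lr(4, 2))   # 0.01 (batch size 8)
print(scaled_lr(16, 4))  # 0.08 (batch size 64)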

Once we know how to test a downloaded or third-party model, we can move on to training our own. Let's start with the main function of train.py:

def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    # Enable when the input image size is fixed; it speeds things up. Usually
    # off, and only enabled for fixed-scale networks such as SSD512.
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.work_dir is not None:
        # Working directory where training artifacts are saved; if unset, a
        # directory is generated automatically from the .py config file name.
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        # checkpoint file to resume training from
        cfg.resume_from = args.resume_from
    cfg.gpus = args.gpus

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    # build the detector model
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # build the training dataset from the config
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        datasets.append(build_dataset(cfg.data.val))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__,
            config=cfg.text,
            CLASSES=datasets[0].CLASSES)  
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)

if __name__ == '__main__':
    main()

As before, we start with the argument parsing:

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work_dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume_from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--gpus',
        type=int,
        default=1,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--autoscale-lr',
        action='store_true',
        help='automatically scale lr with the number of gpus')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args

Once the training command is issued, cfg reads in the relevant configuration:

    args = parse_args()
    cfg = Config.fromfile(args.config)

We can see that some checks and CLI-driven updates are applied to the config before training:

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = args.gpus

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

Here we can see the learning-rate scaling controlled by args.autoscale_lr; the comment explicitly names the linear scaling rule discussed above.

Then the model and the training dataset are built from the config:

    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    datasets = [build_dataset(cfg.data.train)]

Setting up the validation set:

    if len(cfg.workflow) == 2:
        datasets.append(build_dataset(cfg.data.val))
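
This checks the workflow option in the config: the val dataset is only built when the workflow has two phases. A sketch of the two typical settings:

# in the config file
workflow = [('train', 1)]                 # train only: len(cfg.workflow) == 1
# workflow = [('train', 1), ('val', 1)]   # alternate 1 train epoch with 1 val epoch:
#                                         # len(cfg.workflow) == 2, so cfg.data.val is built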

Setting up checkpointing:

    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__,
            config=cfg.text,
            CLASSES=datasets[0].CLASSES)
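
Thanks to this meta dict, every saved checkpoint carries the mmdet version, the full config text and the class names. A minimal sketch of reading them back (the checkpoint path is hypothetical):

import torch

ckpt = torch.load('work_dirs/my_experiment/latest.pth', map_location='cpu')
print(ckpt['meta']['mmdet_version'])  # version of mmdet used for training
print(ckpt['meta']['CLASSES'])        # class names the model was trained on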

Attaching the class names to the model and launching the training:

    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)

Let's look at the definition of this train_detector function:

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
    if logger is None:
        logger = get_root_logger(cfg.log_level)

    # start training
    if distributed:
        _dist_train(model, dataset, cfg, validate=validate)
    else:
        _non_dist_train(model, dataset, cfg, validate=validate)

As we can see, training is split into a distributed path and a non-distributed path.

Here is the distributed training setup:

def _dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)
        for ds in dataset
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level)

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config,
                                             **fp16_cfg)
    else:
        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        val_dataset_cfg = cfg.data.val
        eval_cfg = cfg.get('evaluation', {})
        if isinstance(model.module, RPN):
            # TODO: implement recall hooks for other datasets
            runner.register_hook(
                CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
        else:
            dataset_type = DATASETS.get(val_dataset_cfg.type)
            if issubclass(dataset_type, datasets.CocoDataset):
                runner.register_hook(
                    CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
            else:
                runner.register_hook(
                    DistEvalmAPHook(val_dataset_cfg, **eval_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

And the non-distributed training:

def _non_dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False) for ds in dataset
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level)
    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=False)
    else:
        optimizer_config = cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

We can see that _non_dist_train has no handling of validate at all. (Which raises a question: why is there no validation in the non-distributed path, and what would happen if we forced the code in? A rough sketch of one way to try it follows.)
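
As a thought experiment only, one could register a custom hook on the Runner. This is a minimal sketch assuming mmcv's Hook interface; SimpleEvalHook and its evaluation logic are hypothetical and not part of mmdetection:

from mmcv.runner import Hook
import torch

class SimpleEvalHook(Hook):
    """Hypothetical single-GPU eval hook: runs inference over a val
    dataloader after every training epoch."""

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def after_train_epoch(self, runner):
        runner.model.eval()
        results = []
        with torch.no_grad():
            for data in self.dataloader:
                # return_loss=False switches the detector to inference mode
                results.append(
                    runner.model(return_loss=False, rescale=True, **data))
        # computing a metric such as mAP from `results` would go here
        runner.model.train()

It could then be registered in _non_dist_train before runner.run, e.g. runner.register_hook(SimpleEvalHook(val_loader)); whether the distributed-only eval hooks above could be reused this way is exactly the open question raised here.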

Other references:
mmdetection introduction: Preface
mmdetection introduction: test.py walkthrough
mmdetection introduction: train.py walkthrough (this article)
mmdetection introduction: model walkthrough
