MMDetection学习笔记(五):整体构建流程与代码解析

MMDetection学习笔记(五):整体构建流程与代码解析

写在前面:建议先看完博主的另一篇博客核心组件分析,再去理解整个代码逻辑,结合代码反复阅读,抓住其中面向对象编程的核心思想,祝顺利,欢迎留言评论,博主会定期解答!

整体构建流程

按照数据流过程,训练流程可以简单总结为:

  1. 获取config配置并初始化各种类的实例化,通过Runner进行全生命周期管理:
    (1)Model类初始化,并根据是否多卡训练,进一步对Model类的上层进一步封装,若是分布式(单机多卡或多机多卡)训练,则初始化MMDistributedDataParallel类,若单机训练,则初始化MMDataParallel类;这两个类不仅可以处理 DataContainer 对象,还额外实现了 train_step() 和 val_step() 两个函数,可以被 Runner 调用。
    (2)Dataset类初始化,在迭代输出数据的时候需要通过数据 Pipeline 对数据进行各种处理,最典型的处理流是训练中的数据增强操作,测试中的数据预处理等等;将 Sampler(通过 Sampler 采样器可以控制 Dataset 输出的数据顺序,最常用的是随机采样器 RandomSampler。由于 Dataset 中输出的图片大小不一样,为了尽可能减少后续组成 batch 时 pad 的像素个数,MMDetection 引入了分组采样器 GroupSampler 和 DistributedGroupSampler,相当于在 RandomSampler 基础上额外新增了根据图片宽高比进行 group 功能)和 Dataset 都输入给 DataLoader,然后通过 DataLoader 输出已组成 batch 的数据,作为 Model 的输入;
    (3)Runner类初始化,它负责管理每一个epoch和iteration的train或val,还负责调用hook实现功能扩展,从而方便地获取、修改和拦截任何生命周期数据流。
    (4)Logger、Hook等类的初始化。
  2. Model 运行,输出 loss 以及其他一些信息,会通过 logger 进行保存或者可视化;
  3. 根据loss计算梯度并更新权重;

而测试流程就比较简单了,直接对 DataLoader 输出的数据进行前向推理即可,还原到最终原图尺度过程也是在 Model 中完成。

以上就是 MMDetection 框架整体训练和测试抽象流程,上图不仅仅反映了训练和测试数据流,而且还包括了模块和模块之间的调用关系。对于训练而言,最核心部分应该是 Runner,理解了 Runner 的运行流程,也就理解了整个 MMDetection 数据流。

代码解析

训练流程

1、初始化配置、logger、model、datasets、runner等,调用runner.run()函数

#=================== tools/train.py ==================
# 1.初始化配置
cfg = Config.fromfile(args.config)

# 2.判断是否为分布式训练模式

# 3.初始化 logger
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

# 4.收集运行环境并且打印,方便排查硬件和软件相关问题
env_info_dict = collect_env()

# 5.初始化 model
model = build_detector(cfg.model, ...)

# 6.初始化 datasets

#=================== mmdet/apis/train.py ==================
# 1.初始化 data_loaders ,内部会初始化 GroupSampler、DistributedSampler、DistributedGroupSampler
data_loader = DataLoader(dataset,...)

# 2.基于是否使用分布式训练,初始化对应的 DataParallel
if distributed:
  model = MMDistributedDataParallel(...)
else:
  model = MMDataParallel(...)

# 3.初始化 runner
runner = EpochBasedRunner(...)

# 4.注册必备 hook
runner.register_training_hooks(cfg.lr_config, optimizer_config,
                               cfg.checkpoint_config, cfg.log_config,
                               cfg.get('momentum_config', None))

# 5.如果需要 val,则还需要注册 EvalHook           
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

# 6.注册用户自定义 hook
runner.register_hook(hook, priority=priority)

# 7.权重恢复和加载
if cfg.resume_from:
    runner.resume(cfg.resume_from)
elif cfg.load_from:
    runner.load_checkpoint(cfg.load_from)

# 8.运行,开始训练
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

runner 对象内部的 run 方式是一个通用方法,可以运行任何 workflow,目前常用的主要是 train 和 val。

  • 当配置为:workflow = [(‘train’, 1)],表示仅仅进行 train workflow,也就是迭代训练
  • 当配置为:workflow = [(‘train’, n),(‘val’, 1)],表示先进行 n 个 epoch 的训练,然后再进行1个 epoch 的验证,然后循环往复,如果写成 [(‘val’, 1),(‘train’, n)] 表示先进行验证,然后才开始训练

2、调用runner中的 train() 或者 val()
当进入对应的 workflow,则会调用 runner 里面的 train() 或者 val(),表示进行一次 epoch 迭代,如下所示:

def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    # 在每一次epoch训练前调用hook
    self.call_hook('before_train_epoch')
    for i, data_batch in enumerate(self.data_loader):
    	# 在每一次iter训练前调用hook
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True)
        # 在每一次iter训练后调用hook
        self.call_hook('after_train_iter')
	# 在每一次epoch训练后调用hook
    self.call_hook('after_train_epoch')


def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    # 在每一次epoch验证前调用hook
    self.call_hook('before_val_epoch')
    for i, data_batch in enumerate(self.data_loader):
    	# 在每一次iter验证前调用hook
        self.call_hook('before_val_iter')
        with torch.no_grad():
            self.run_iter(data_batch, train_mode=False)
        # 在每一次iter验证后调用hook
        self.call_hook('after_val_iter')
   	# 在每一次epoch验证后调用hook
    self.call_hook('after_val_epoch')

在每一个epoch不断迭代循环,实现每一个iteration,核心函数实际上是 self.run_iter(),如下:

def run_iter(self, data_batch, train_mode, **kwargs):
    if train_mode:
        # 对于每次迭代,最终是调用如下函数
        outputs = self.model.train_step(data_batch,...)
    else:
        # 对于每次迭代,最终是调用如下函数
        outputs = self.model.val_step(data_batch,...)

    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'],...)
    self.outputs = outputs

3、runner 中调用 train_step 或者 val_step

#=================== mmcv/runner/epoch_based_runner.py ==================
if train_mode:
    outputs = self.model.train_step(data_batch,...)
else:
    outputs = self.model.val_step(data_batch,...)

实际上,首先会调用 DataParallel 中的 train_step 或者 val_step ,其具体调用流程为:

# 非分布式训练
#=================== mmcv/parallel/data_parallel.py/MMDataParallel ==================
def train_step(self, *inputs, **kwargs):
    if not self.device_ids:
        inputs, kwargs = self.scatter(inputs, kwargs, [-1])
        # 此时才是调用 model 本身的 train_step
        return self.module.train_step(*inputs, **kwargs)
    # 单 gpu 模式
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    # 此时才是调用 model 本身的 train_step
    return self.module.train_step(*inputs[0], **kwargs[0])

# val_step 也是的一样逻辑
def val_step(self, *inputs, **kwargs):
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    # 此时才是调用 model 本身的 val_step
    return self.module.val_step(*inputs[0], **kwargs[0])

可以发现,在调用 model 本身的 train_step 前,需要额外调用 scatter 函数,该函数的作用是处理 DataContainer 格式数据,使其能够组成 batch,否则程序会报错。

如果是分布式训练,则调用的实际上是 mmcv/parallel/distributed.py/MMDistributedDataParallel,最终调用的依然是 model 本身的 train_step 或者 val_step。

4、调用 model 中的 train_step 或者 val_step

#=================== mmdet/models/detectors/base.py/BaseDetector ==================
def train_step(self, data, optimizer):
    # 实例():调用__call__()函数,在函数内部会调用本类自身的 forward 方法
    losses = self(**data)
    # 解析 loss
    loss, log_vars = self._parse_losses(losses)
    # 返回字典对象
    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
    return outputs

def forward(self, img, img_metas, return_loss=True, **kwargs):
    if return_loss:
        # 训练模式
        return self.forward_train(img, img_metas, **kwargs)
    else:
        # 测试模式
        return self.forward_test(img, img_metas, **kwargs)

forward_train 和 forward_test 需要在不同的算法子类中实现,输出是 Loss 或者 预测结果。

5、调用子类中的 forward_train 方法

目前提供了两个具体子类,TwoStageDetector 和 SingleStageDetector ,用于实现 two-stage 和 single-stage 算法。

对于 TwoStageDetector 而言,其核心逻辑是:

#============= mmdet/models/detectors/two_stage.py/TwoStageDetector ============
def forward_train(...):
    # 先进行 backbone+neck 的特征提取
    x = self.extract_feat(img)
    losses = dict()
    # RPN forward and loss
    if self.with_rpn:
        # 训练 RPN
        proposal_cfg = self.train_cfg.get('rpn_proposal',
                                          self.test_cfg.rpn)
        # 主要是调用 rpn_head 内部的 forward_train 方法
        rpn_losses, proposal_list = self.rpn_head.forward_train(x,...)
        losses.update(rpn_losses)
    else:
        proposal_list = proposals
    # 第二阶段,主要是调用 roi_head 内部的 forward_train 方法
    roi_losses = self.roi_head.forward_train(x, ...)
    losses.update(roi_losses)
    return losses

对于 SingleStageDetector 而言,其核心逻辑是:

#============= mmdet/models/detectors/single_stage.py/SingleStageDetector ============
def forward_train(...):
    super(SingleStageDetector, self).forward_train(img, img_metas)
    # 先进行 backbone+neck 的特征提取
    x = self.extract_feat(img)
    # 主要是调用 bbox_head 内部的 forward_train 方法
    losses = self.bbox_head.forward_train(x, ...)
    return losses

如果在TwoStageDetector 和 SingleStageDetector基础上封装了其他类,就会调用新类中的forward_train函数。

测试流程

对于测试逻辑由于比较简单,就不详细描述了,简单来说测试流程下不需要 runner,直接加载训练好的权重,然后进行 model 推理即可,下面简要概述:

  1. 调用 MMDataParallel 或 MMDistributedDataParallel 中的 forward 方法;
  2. 调用 base.py 中的 forward 方法;
  3. 调用 base.py 中的 self.forward_test 方法;
  4. 如果是单尺度测试,则会调用 TwoStageDetector 或 SingleStageDetector 中的 simple_test 方法,如果是多尺度测试,则调用 aug_test 方法;
  5. 最终调用的是每个具体算法 Head 模块的 simple_test 或者 aug_test 方法。

猜你喜欢

转载自blog.csdn.net/weixin_43603658/article/details/131558980