MMDetection学习笔记(四):核心组件分析

核心组件分析

此篇博客注重分析了MMDetection中三大核心组件:Registry、Hook和Runner。

Registry

Registry 机制其实维护的是一个全局字典,实现字符串到类的映射。通过 Registry 类,用户可以通过config中字符串的方式实例化任何想要的类(或模块)。Registry的优点在于:解耦性强、可扩展性强,代码更易理解。

MMCV中Registry类的实现源码:

class Registry:
    def __init__(self, name):
        # 可实现注册类细分功能
        self._name = name 
        # 内部核心内容,维护所有的已经注册好的 class
        self._module_dict = dict()

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {
      
      type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{
      
      module_name} is already registered '
                           f'in {
      
      self.name}')
        # 最核心代码
        self._module_dict[module_name] = module_class

    # 装饰器函数
    def register_module(self, name=None, force=False, module=None):
        if module is not None:
            # 如果已经是 module,那就知道 增加到字典中即可
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # 最标准用法
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls
        return _register

在 MMCV 中所有的类实例化都是通过build_from_cfg函数实现,做的事情非常简单,就是给定module_name,然后从 self._module_dict 提取即可。

def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type') # 注册 str 类名
    if is_str(obj_type):
        # 相当于 self._module_dict[obj_type]
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{
      
      obj_type} is not in the {
      
      registry.name} registry')

    # 如果已经实例化了,那就直接返回
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {
      
      type(obj_type)}')

    # 最终初始化对于类,并且返回,就完成了一个类的实例化过程
    return obj_cls(**args)

一个完整的使用例子如下:

# registry
CONVERTERS = Registry('converter')

@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

# config
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = build_from_cfg(converter_cfg,CONVERTERS)

Hook

Hook的定义

在 wiki 百科中定义如下:

钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)

简单来说,Hook机制可以在代码运行的整个生命周期中无侵入地拓展功能。Hook 机制在 OpenMMLab 系列框架中应用非常广泛,结合 Runner 类可以实现对训练过程的整个生命周期进行管理。同时内置了多种 Hook,通过注册的形式注入 Runner 中实现了丰富的扩展功能,例如模型权重保存、日志记录、lr超参数的调整等等。

Hook的调用机制

在MMDetection中,Hook 是可以注册进 Runner 中,不同类型的 Hook 实现了不同的生命周期方法从而完成不同的功能,以一个典型的训练过程为例,EpochBasedRunner(以 epoch 为单位) 中生命周期方法如下所示:

# 开始运行时调用
before_run()

while self.epoch < self._max_epochs:

    # 开始 epoch 迭代前调用
    before_train_epoch()

    for i, data_batch in enumerate(self.data_loader):
        # 开始一次(iteration)迭代前调用
        before_train_iter()

        self.model.train_step()

        # 经过一次(iteration)迭代后调用
        after_train_iter()

    # 经过一个 epoch 迭代后调用
    after_train_epoch()

# 运行完成前调用
after_run()

只要注册的 Hook 对象实现了某一个或者某几个生命周期方法,当 Runner 运行到预定义的位点时候就会调用对应的 Hook 中方法。

Hook的分类与用法

MMCV中实现的Hook有默认Hook和定制Hook,默认 Hook不需要用户自行注册,用户通过 (hook 名)_config 配置对应参数即可;而对于定制 Hook,则需要用户手动注册或者通过配置方式注册进去。

对于默认 Hook,在 MMDetection 框架训练过程中,其注册代码为:

runner.register_training_hooks(cfg.lr_config, optimizer_config,
                               cfg.checkpoint_config, cfg.log_config,
                               cfg.get('momentum_config', None))

register_training_hooks函数的接收参数其实是字典参数,Runner 内部会根据配置自动生成对应的 Hook 实例,典型的 lr_config 为:

lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[16, 22])

对于定制类 Hook,其注册源码如下:

# user-defined hooks
if cfg.get('custom_hooks', None):
    custom_hooks = cfg.custom_hooks
    for hook_cfg in cfg.custom_hooks:
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        # 通过配置实例化定制 hook
        hook = build_from_cfg(hook_cfg, HOOKS)
        # 注册
        runner.register_hook(hook, priority=priority)

以 EMAHook 为例,其 .py 配置文件应该写成:

custom_hooks=[dict(type='EMAHook')]

下面对一些比较通用的、常用的 Hook 进行功能简析:

  1. CheckpointHook
    CheckpointHook 主要是对模型参数进行保存,如果是分布式多卡训练,则仅仅会在 master 进程保存。同时可以通过max_keep_ckpts参数设置最多保存多少个权重文件,早期额外的权重会自动删除。
    如果以 epoch 为单位进行保存,则该 Hook 实现after_train_epoch方法即可,否则仅实现after_train_iter 方法即可。
  2. LrUpdaterHook
    LrUpdaterHook 用于学习率调度,为了统一代码风格以及方便扩展,MMDetection 等训练框架并没有直接继承 PyTorch 提供的学习率调度器,而是通过 LrUpdaterHook 实现。
    如果是以 iter 为单位,则仅仅需要在 before_train_iter 方法中实现学习率调度功能,如果是以 epoch 为单位,则还需要在 before_train_epoch中实现相关操作。简单来说要实时改变学习率。
  3. OptimizerHook
    OptimizerHook 功能比较简单:梯度反向传播加上参数更新,如果指定了梯度裁剪参数,则可以进行梯度裁剪。
  4. ClosureHook
    ClosureHook 比较特殊,他的主要功能是提供最简洁的函数注册。
    可以想象一个场景:在训练过程中,想知道目前的迭代次数,在目前框架体系下最优雅的实现方式是:用户自己写一个获取 iter 的 Hook 类,然后在配置文件中通过 custom_hooks注册进去,该类的代码如下所示:
    @HOOKS.register_module()
    class GetIterHook(Hook):
        def after_train_iter(self, runner):
            print(runner.iter)
    
    可以发现你需要做如下事情:
    (1)写一个 GetIterHook 类,继承自 Hook;
    (2)在类上方加上 @HOOKS.register_module();
    (3)在对应的 init.py 文件中进行导入;
    (4)将该 Hook 注册到 Runner 中。
    需要完成三个步骤,但是实际上我只是想 print 而已,比较繁琐,而 ClosureHook 的作用就是为了简化流程。你现在要做的事情如下所示:
    def getiter(runner):
        print(runner.iter)
    
    (1)定义如上函数;
    (2)作为参数输入给 ClosureHook,并且实例化 ClosureHook(‘after_train_iter’, getiter);
    (3)将该 Hook 注册到 Runner 中。
    ClosureHook 主要用于一些非常简单的 Hook,但是又不想重新定义一个类来实现,此时就可以通过定义函数,然后传递给 ClosureHook 即可。

Runner

Runner负责OpenMMLab中所有框架pipeline的过程调度,提供了 以Epoch 和 Iter 为基础的迭代模式以满足不同场景,例如 MMDetection 默认采用 Epoch (配置文件中相关参数都是以 Epoch 为单位),而 MMSegmentation 默认采用 Iter (配置文件中相关参数都是以 Iter 为单位)。配合各类 Hook,以一种优雅的方式实现功能的扩展。

Runner 的使用过程可以分成 4 个步骤:

  1. Runner 对象初始化;
  2. 注册各类 Hook 到 Runner 中;
  3. 调用 Runner 的 resume 或者 load_checkpoint 方法对权重进行加载;
  4. 运行给定的pipeline工作流。

Runner 初始化

考虑到 Epoch 和 Iter 模式有很多共有逻辑,为了复用,抽象出一个 BaseRunner。BaseRunner 初始化是一个常规初始化过程,其参数如下:

def __init__(self,
             model,
             batch_processor=None, # 已废弃
             optimizer=None,
             work_dir=None,
             logger=None,
             meta=None, # 提供了该参数,则会保存到 ckpt 中
             max_iters=None, # 这两个参数非常关键,如果没有给定,则内部自己计算
             max_epochs=None):

注册 Hook

register_training_hooks,注册默认Hook:

def register_training_hooks(self,
                            lr_config, # lr相关
                            optimizer_config=None, # 优化器相关
                            checkpoint_config=None, # ckpt 保存相关
                            log_config=None, # 日志记录相关
                            momentum_config=None, # momentum 相关
                            timer_config=dict(type='IterTimerHook')) # 迭代时间统计

register_hook,上面以外的其他所有 Hook,都是通过本方式进行注册,例如 eval_hook、custom_hooks 和 DistSamplerSeedHook 等等:

def register_hook(self, hook, priority='NORMAL'):
    # 获取优先级
    priority = get_priority(priority)
    hook.priority = priority
    # 基于优先级计算当前 hook 插入位置
    inserted = False
    for i in range(len(self._hooks) - 1, -1, -1):
        if priority >= self._hooks[i].priority:
            self._hooks.insert(i + 1, hook)
            inserted = True
            break
    if not inserted:
        self._hooks.insert(0, hook)

resume 或者 load_checkpoint

resume 方法用于训练过程中停止然后恢复训练时加载权重,而 load_checkpoint 仅仅是加载预训练权重,这个预训练权重可以来自官方,也可以来自自己训练后的权重,如果有 key 不匹配的参数则会自动跳过。

run

run 方法调用后才是真正开启工作流,并且由于 Epoch 和 Iter 模式流程不一样,所以在各自子类实现。

(1) EpochBasedRunner run

def run(self, 
    data_loaders, # dataloader 列表
    workflow,  # 工作流列表,长度需要和 data_loaders 一致
    max_epochs=None, 
    **kwargs):
  • 假设只想运行训练工作流,则可以设置 workflow = [(‘train’, 1)],表示 data_loader 中的数据进行迭代训练
  • 假设想运行训练和验证工作流,则可以设置 workflow = [(‘train’, 3), (‘val’,1)],表示先训练 3 个 epoch ,然后切换到 val 工作流,运行 1 个 epoch,然后循环,直到训练 epoch 次数达到指定值
  • 工作流设置非常自由,例如你可以先验证再训练 workflow = [(‘val’, 1), (‘train’,1)]

需要注意的是:如果工作流有两个,那么 data_loaders 中也需要提供两个 dataloader。其核心逻辑如下:

def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    # epoch 模式,需要自动计算出 _max_iters
    for i, flow in enumerate(workflow):
        mode, epochs = flow
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    # 调用注册到 runner 中的所有 hook 的 before_run 方法,表示开启 run 前
    self.call_hook('before_run')

    # 如果没有达到退出条件,就一直运行工作流
    while self.epoch < self._max_epochs:
        # 遍历工作流
        for i, flow in enumerate(workflow):
            # 模式,和当前工作流需要运行的 epoch 次数
            mode, epochs = flow
            epoch_runner = getattr(self, mode)
            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                # 开始一个 epoch 的迭代
                epoch_runner(data_loaders[i], **kwargs)
    time.sleep(1)  # wait for some hooks like loggers to finish

    # 调用注册到 runner 中的所有 hook 的 after_run 方法,表示结束 run 后
    self.call_hook('after_run')

run 方法中定义的是通用工作流切换流程,真正完成一个 epoch 工作流是调用了工作流函数。目前支持 train 和 val 两个工作流,那么 epoch_runner(data_loaders[i], **kwargs) 调用的实际上是 train 或者 val 方法:

# train 和 val 方法逻辑非常相似
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True)
        self.call_hook('after_train_iter')
        self._iter += 1
    self.call_hook('after_train_epoch')
    self._epoch += 1

@torch.no_grad()
def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        self.run_iter(data_batch, train_mode=False)
        self.call_hook('after_val_iter')
    self.call_hook('after_val_epoch')

上述逻辑是遍历 data_loader,然后进行 batch 级别的迭代训练或者验证,比较容易理解。真正完成一个 batch 的训练或者验证是调用了 self.run_iter

# 简化逻辑
def run_iter(self, data_batch, train_mode, **kwargs):
    # 调用 model 自身的 train_step 或者 val_step 方法
    if train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer,
                                        **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)

    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])

    self.outputs = outputs

(2) IterBasedRunner run

IterBasedRunner模式以迭代次数作为循环终止的条件, 没有 epoch 的概念,故 IterBasedRunner 的 run 方法有些许改动:

  • 工作流终止条件不再是 epoch,而是 iter
  • Hook 的生命周期方法也不涉及 epoch,全部是 iter 相关方法

由于MMDetection采用的EpochBasedRunner,而非IterBasedRunner,其详细的代码逻辑不再展开。

(3)EpochBasedRunner与IterBasedRunner比较

假设数据长度是 1024,batch=4,那么 dataloader 长度是 1024/4=256, 也就是一个 epoch 是 256 次迭代,在 Iter 训练模式下,计划训练 100000 个迭代,若在Epoch训练模式下,那么实际上运行了 100000//256=39 个 epoch。

猜你喜欢

转载自blog.csdn.net/weixin_43603658/article/details/129113196