MMDetection: Understanding the build_optimizer Module


Preface

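Earlier articles covered build_dataset, build_dataloader and build_model in detail. The optimizer is the last ingredient needed for "alchemy" (model training). This article explains how MMDetection builds its optimizer.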

1. Overall flow

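The overall flow resembles the way the model is built. First, MMDetection sets up optimizer registries, and one of them registers the DefaultOptimizerConstructor class. Then build_from_cfg instantiates an optimizer object from the optimizer config dict. The sections below walk through the internals of each component.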
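The "registry + config dict → object" pattern behind this flow fits in a few lines. The sketch below is simplified and is not MMCV's implementation. MiniRegistry, mini_build_from_cfg and ToySGD are hypothetical names. Only the core idea is kept: a registry maps a class name to a class, and the builder pops `type` from a copy of the config and passes every other key as a keyword argument.

```python
class MiniRegistry:
    """Toy registry: maps a class name to the class object (simplified sketch)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # As with OPTIMIZERS.register_module()(cls) in MMCV, calling
        # register_module() returns a decorator that stores the class.
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict.get(key)


def mini_build_from_cfg(cfg, registry):
    # Copy first so the caller's config dict is not mutated by pop().
    args = dict(cfg)
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
        raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    # Every remaining key becomes a keyword argument of the constructor.
    return obj_cls(**args)


TOY_OPTIMIZERS = MiniRegistry('optimizer')


@TOY_OPTIMIZERS.register_module()
class ToySGD:
    """Hypothetical stand-in for an optimizer class."""

    def __init__(self, lr, momentum=0.0, weight_decay=0.0):
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay


opt = mini_build_from_cfg(dict(type='ToySGD', lr=0.00125, momentum=0.9),
                          TOY_OPTIMIZERS)
print(type(opt).__name__, opt.lr, opt.momentum)  # ToySGD 0.00125 0.9
```

MMDetection applies this pattern twice: once to pick an optimizer *constructor* and once to pick the optimizer itself.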

2. The optimizer config dict

As in the earlier articles, the default config faster_rcnn_r50_fpn.py serves as the example. Its optimizer-related field is:

# optimizer
optimizer = dict(type='SGD', lr=0.00125, momentum=0.9, weight_decay=0.0001)

So SGD is the default optimizer here.
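Roughly speaking, this dict ends up as a plain PyTorch constructor call. `type` selects the class and the remaining keys become keyword arguments. The sketch below is not MMCV code. `getattr(torch.optim, ...)` stands in for the OPTIMIZERS registry lookup, and `nn.Linear` is a hypothetical placeholder for the detector.

```python
import torch
from torch import nn

optimizer_cfg = dict(type='SGD', lr=0.00125, momentum=0.9, weight_decay=0.0001)

model = nn.Linear(4, 2)  # hypothetical stand-in for the detector

kwargs = dict(optimizer_cfg)                          # copy; optimizer_cfg keeps 'type'
optim_cls = getattr(torch.optim, kwargs.pop('type'))  # 'SGD' -> torch.optim.SGD
optimizer = optim_cls(model.parameters(), **kwargs)   # lr, momentum, weight_decay as kwargs

print(type(optimizer).__name__)  # SGD
print(optimizer.defaults['lr'])  # 0.00125
```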

3. Optimizer registries

MMDetection (through MMCV) creates two registries:

OPTIMIZERS = Registry('optimizer')
OPTIMIZER_BUILDERS = Registry('optimizer builder')

3.1 The OPTIMIZERS registry

The OPTIMIZERS registry holds the optimizers that PyTorch provides in torch.optim.
Here is how this part is built (mmcv/runner/optimizer/builder.py). The code walks torch.optim with dir() and registers every class that subclasses torch.optim.Optimizer via OPTIMIZERS.register_module()(_optim).

import inspect

import torch

# OPTIMIZERS is the Registry('optimizer') created above.


def register_torch_optimizers():
    torch_optimizers = []
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            # Register each built-in torch optimizer class in OPTIMIZERS.
            OPTIMIZERS.register_module()(_optim)
            torch_optimizers.append(module_name)
    return torch_optimizers


TORCH_OPTIMIZERS = register_torch_optimizers()
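The standalone sketch below shows what this loop actually picks up. It runs the same scan but stores results in a plain dict instead of the MMCV registry, so no MMCV install is needed. collect_torch_optimizers is a hypothetical name, and the exact set of names depends on your PyTorch version.

```python
import inspect

import torch


def collect_torch_optimizers():
    """Map class name -> class for every torch.optim.Optimizer subclass."""
    found = {}
    for name in dir(torch.optim):
        if name.startswith('__'):
            continue
        obj = getattr(torch.optim, name)
        # Submodules such as torch.optim.lr_scheduler are not classes and are skipped.
        if inspect.isclass(obj) and issubclass(obj, torch.optim.Optimizer):
            found[name] = obj
    return found


torch_optimizers = collect_torch_optimizers()
print('SGD' in torch_optimizers, 'Adam' in torch_optimizers)  # True True
```

One side effect: `issubclass(torch.optim.Optimizer, torch.optim.Optimizer)` is True, so the abstract base class `Optimizer` itself also lands in the registry.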

3.2 The OPTIMIZER_BUILDERS registry

The other registry mainly registers the class below (mmcv/runner/optimizer/default_constructor.py). Only the docstring and __init__ are excerpted here.

@OPTIMIZER_BUILDERS.register_module()
class DefaultOptimizerConstructor:
    """Default constructor for optimizers

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are

                - `type`: class name of the optimizer.

            Optional fields are

                - any arguments of the corresponding optimizer type, e.g.,
                  lr, weight_decay, momentum, etc.
        paramwise_cfg (dict, optional): Parameter-wise options.

    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001)
        >>> paramwise_cfg = dict(norm_decay_mult=0.)
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)

    Example 2:
        >>> # assume model have attribute model.backbone and model.cls_head
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
        >>> paramwise_cfg = dict(custom_keys={
                '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)
        >>> # Then the `lr` and `weight_decay` for model.backbone is
        >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
        >>> # model.cls_head is (0.01, 0.95).
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict, '
                            f'but got {type(optimizer_cfg)}')
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
        self.base_lr = optimizer_cfg.get('lr', None)
        self.base_wd = optimizer_cfg.get('weight_decay', None)
        self._validate_cfg()
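Example 2 in the docstring is just multiplication. Backbone parameters get lr = 0.01 × 0.1 = 0.001 and weight_decay = 0.95 × 0.9 = 0.855, and everything else keeps (0.01, 0.95). The sketch below reproduces that lookup for a single parameter name. It is simplified: effective_hparams is a hypothetical helper, and it ignores MMCV's other options such as bias_lr_mult and norm_decay_mult. It follows the matching rule described in MMCV's documentation for custom_keys: the longest key that is a substring of the parameter name wins, with ties broken alphabetically. Names in this sketch look like 'backbone.conv1.weight' (no leading dot), so the key is 'backbone' rather than the docstring's '.backbone'.

```python
def effective_hparams(param_name, base_lr, base_wd, custom_keys):
    """Return (lr, weight_decay) for one parameter (simplified sketch)."""
    # Alphabetical first, then longest first: the stable sort keeps
    # alphabetical order among keys of equal length.
    for key in sorted(sorted(custom_keys), key=len, reverse=True):
        if key in param_name:
            mults = custom_keys[key]
            lr = base_lr * mults.get('lr_mult', 1.0)
            wd = base_wd * mults.get('decay_mult', 1.0)
            return lr, wd
    # No key matched: fall back to the global settings.
    return base_lr, base_wd


custom_keys = {'backbone': dict(lr_mult=0.1, decay_mult=0.9)}
print(effective_hparams('backbone.conv1.weight', 0.01, 0.95, custom_keys))
print(effective_hparams('cls_head.fc.weight', 0.01, 0.95, custom_keys))  # (0.01, 0.95)
```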

4. Instantiating the optimizer object

With the config dict and the registries in place, the optimizer object can be instantiated. The entry point is in mmdet/apis/train.py:

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

Here are build_optimizer and its helper build_optimizer_constructor:

import copy

from mmcv.utils import build_from_cfg

# OPTIMIZER_BUILDERS is the Registry('optimizer builder') created in Section 3.


def build_optimizer_constructor(cfg):
    # Instantiate the optimizer *constructor* (not the optimizer itself).
    return build_from_cfg(cfg, OPTIMIZER_BUILDERS)


def build_optimizer(model, cfg):
    optimizer_cfg = copy.deepcopy(cfg)
    # Fall back to DefaultOptimizerConstructor when there is no 'constructor' key.
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    # Likewise, paramwise_cfg defaults to None when absent.
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    # Under the hood this is just build_from_cfg on OPTIMIZER_BUILDERS.
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)
    return optimizer
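The sketch below makes the two-stage lookup concrete. It reproduces only the dictionary handling at the top of build_optimizer and imports nothing from MMCV. split_optimizer_cfg is a hypothetical helper name. The outer dict's `type` is the constructor name. The optimizer's own `type` ('SGD') travels inside `optimizer_cfg` until the constructor is called.

```python
import copy


def split_optimizer_cfg(cfg):
    """Mirror the dict handling in build_optimizer (sketch, no MMCV import)."""
    optimizer_cfg = copy.deepcopy(cfg)  # the caller's cfg stays untouched
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    # This is the dict that build_from_cfg receives with OPTIMIZER_BUILDERS.
    return dict(type=constructor_type,
                optimizer_cfg=optimizer_cfg,
                paramwise_cfg=paramwise_cfg)


cfg = dict(type='SGD', lr=0.00125, momentum=0.9, weight_decay=0.0001)
builder_cfg = split_optimizer_cfg(cfg)
print(builder_cfg['type'])                   # DefaultOptimizerConstructor
print(builder_cfg['optimizer_cfg']['type'])  # SGD
print(builder_cfg['paramwise_cfg'])          # None
```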

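As build_optimizer shows, it calls build_optimizer_constructor, which in turn calls build_from_cfg to instantiate the constructor class. The result is the optim_constructor object. Now for the interesting part. So far only the OPTIMIZER_BUILDERS registry has been used, and OPTIMIZERS has not appeared at all. It comes in through this call: optimizer = optim_constructor(model).
Go back to the DefaultOptimizerConstructor class from Section 3.2. It implements a __call__ method, excerpted here: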

    def __call__(self, model):
        if hasattr(model, 'module'):
            model = model.module

        optimizer_cfg = self.optimizer_cfg.copy()
        # if no paramwise option is specified, just use the global setting
        if not self.paramwise_cfg:
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)

        # set param-wise lr and weight decay recursively
        params = []
        self.add_params(params, model)
        optimizer_cfg['params'] = params

        return build_from_cfg(optimizer_cfg, OPTIMIZERS)

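Inside __call__, build_from_cfg(optimizer_cfg, OPTIMIZERS) creates the actual optimizer object. At that point optimizer_cfg still carries type='SGD', and its 'params' entry holds either model.parameters() or the list built by add_params.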
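When paramwise_cfg is non-empty, `params` is a list of dicts rather than a plain iterator of tensors. PyTorch accepts that form directly as per-parameter option groups. The sketch below assumes that add_params appends one dict per parameter, which is my reading of MMCV 1.x. ToyDetector is hypothetical, and the `startswith('backbone.')` check is a stand-in for custom_keys matching. Groups without their own `lr` or `weight_decay` inherit the defaults passed to SGD.

```python
import torch
from torch import nn


class ToyDetector(nn.Module):
    """Hypothetical two-part model standing in for a real detector."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.cls_head = nn.Linear(4, 2)


model = ToyDetector()
base_lr, base_wd = 0.01, 0.95

# Assumed shape of add_params' output: one {'params': [param], ...} per parameter.
params = []
for name, param in model.named_parameters():
    group = {'params': [param]}
    if name.startswith('backbone.'):  # stand-in for custom_keys matching
        group['lr'] = base_lr * 0.1
        group['weight_decay'] = base_wd * 0.9
    params.append(group)

# Same shape as optimizer_cfg['params'] = params followed by the registry build.
optimizer = torch.optim.SGD(params=params, lr=base_lr, weight_decay=base_wd)

for group in optimizer.param_groups:
    print(group['lr'], group['weight_decay'])
```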

Summary

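This article described how MMDetection builds an optimizer. Many more code details are worth studying. Real training setups need flexible optimizer settings, and a single OPTIMIZERS registry would make those awkward. MMDetection adds a "factory" layer, OPTIMIZER_BUILDERS. That layer can, for example, optimize only some of the parameters or give selected parameters their own settings. This design pattern is worth learning from.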
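A custom constructor shows this flexibility most clearly. Below is a hypothetical HeadOnlyOptimizerConstructor that optimizes only the parameters whose names start with a given prefix. It is a standalone sketch. `getattr(torch.optim, ...)` stands in for the OPTIMIZERS lookup, and the model is a toy. Its `__init__(optimizer_cfg, paramwise_cfg=None)` signature matches what build_optimizer passes to a constructor.

```python
import torch
from torch import nn


class HeadOnlyOptimizerConstructor:
    """Hypothetical constructor: optimize only parameters under `head_prefix`.

    In an MMCV 1.x project it would be decorated with
    @OPTIMIZER_BUILDERS.register_module(); this sketch stays standalone.
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None, head_prefix='cls_head.'):
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
        self.head_prefix = head_prefix

    def __call__(self, model):
        if hasattr(model, 'module'):  # unwrap a wrapper, as DefaultOptimizerConstructor does
            model = model.module
        cfg = dict(self.optimizer_cfg)
        optim_cls = getattr(torch.optim, cfg.pop('type'))  # stand-in for OPTIMIZERS
        head_params = [p for n, p in model.named_parameters()
                       if n.startswith(self.head_prefix)]
        return optim_cls(head_params, **cfg)


class ToyDetector(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.cls_head = nn.Linear(4, 2)


model = ToyDetector()
constructor = HeadOnlyOptimizerConstructor(dict(type='SGD', lr=0.01, momentum=0.9))
optimizer = constructor(model)
print(sum(len(g['params']) for g in optimizer.param_groups))  # 2
```

To select such a class in MMCV, one would add `constructor='HeadOnlyOptimizerConstructor'` to the optimizer dict. build_optimizer pops that key and builds the constructor from OPTIMIZER_BUILDERS, as shown in Section 4.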

Further reading

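An introduction to the Config class in mmcv
An introduction to the Registry class in mmcv
Building the dataset class in mmdetection
Building the dataloader class in mmdetection
Building the model in mmdetection
Training mmdetection on your own COCO-format dataset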


Origin blog.csdn.net/wulele2/article/details/115048828