MMDetection框架入门教程(四):注册机制详解

  上一篇博客对MMDetection中的配置文件进行了介绍,其中提到,我们在配置文件中配置到模型、数据集、训练策略等后,通过Config类可以将配置文件中的参数信息以字典的形式进行管理,然后MMDetection框架就会对其自动进行解析,帮助我们构建整个算法流程。MMDetection使用注册机制来实现从配置参数到算法模块的构建。 本篇博客将从源码出发,对MMCV中的注册机制进行详细介绍。

  1. 官方文档 - MMCV
  2. 官方知乎 - MMCV 核心组件分析(五): Registry

1. 注册器

  注册机制是MMCV中非常重要的一个概念,在MMDetection中如果你想要增加自己的算法模块或流程,都需要通过注册机制来实现。

1.1 Registry类

  介绍注册机制之前先介绍一下Registry类。

  MMCV使用注册器(Registry)来管理具有相似功能的不同模块,比如ResNet、FPN、RoIHead都属于模型结构,SGD、Adam都属于优化器。注册器内部其实是在维护一个全局的查询表,key是字符串,value是类。

  简单来说,注册器可以看做字符串到类(Class)的映射。借助注册器,用户可以通过字符串查询到对应的类,并实例化该类。有了这个认知后,我再看Registry类的源码就很容易理解了,先看下构造函数,其功能主要是初始化注册器的名字、实例化函数,并初始化一张字典类型的查询表_module_dict

from mmcv.utils import Registry

class Registry:
	# 构造函数
    def __init__(self, name, build_func=None, parent=None, scope=None):
        """
        name (str): 注册器的名字
        build_func(func): 从注册器构建实例的函数句柄
        parent (Registry): 父类注册器
        scope (str): 注册器的域名
        """
        self._name = name
        # 使用module_dict管理字符串到类的映射
        self._module_dict = dict()
        self._children = dict()
        # 如果scope未指定, 默认使用类定义位置所在的包名, 比如mmdet, mmseg
        self._scope = self.infer_scope() if scope is None else scope

        # build_func按照如下优先级初始化:
        # 1. build_func: 优先使用指定的函数
        # 2. parent.build_func: 其次使用父类的build_func
        # 3. build_from_cfg: 默认从config dict中实例化对象
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
            
        # 设置父类-子类的从属关系
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

  比如说,我们现在想要使用注册器来管理我们的模型,首先初始化一个Registry实例MODELS,然后调用Registry类的register_module()方法完成ResNet和VGG类的注册,可以看到最后MODELS的打印结果中包含了这两个类的信息(打印信息中items对应的其实就是self._module_dict),表示注册成功。为了代码简洁,这里推荐使用python的函数装饰器@实现register_module()的调用。然后就可以通过build()函数来实例化我们的模型了。

# 实例化一个注册器用来管理模型
MODELS = Registry('myModels')

# 方式1: 在类的创建过程中, 使用函数装饰器进行注册(推荐)
@MODELS.register_module()
class ResNet(object):
    def __init__(self, depth):
        self.depth = depth
        print('Initialize ResNet{}'.format(depth))

# 方式2: 完成类的创建后, 再显式调用register_module进行注册(不推荐)   
class FPN(object):
    def __init__(self, in_channel):
        self.in_channel= in_channel
        print('Initialize FPN{}'.format(in_channel))
MODELS.register_module(name='FPN', module=FPN)

print(MODELS)
""" 打印结果为:
Registry(name=myModels, items={'ResNet': <class '__main__.ResNet'>, 'FPN': <class '__main__.FPN'>})
"""

# 配置参数, 一般cfg从配置文件中获取
backbone_cfg = dict(type='ResNet', depth=101)
neck_cfg = dict(type='FPN', in_channel=256)
# 实例化模型(将配置参数传给模型的构造函数), 得到实例化对象
my_backbone = MODELS.build(backbone_cfg)
my_neck = MODELS.build(neck_cfg)
print(my_backbone, my_neck)
""" 打印结果为:
Initialize ResNet101
Initialize FPN256
<__main__.ResNet object at 0x000001E68E99E198> <__main__.FPN object at 0x000001E695044B38>
"""

1.2 register_module和build函数

  在实例化一个Registry对象后,类的注册和实例化分别通过register_modulebuild函数完成的,下面来看看这两个函数的源码。

  register_module()内部实际上调用的是self._register_module()函数,功能也很简单,就是将当前要注册的模块名称和模块类型以键值对key->value的形式保存到_module_dict查询表中。

def _register_module(self, module_class, module_name=None, force=False):
	"""
	module_class (class): 要注册的模块类型
	module_name (str): 要注册的模块名称
	force (bool): 是否强制注册
	"""
    if not inspect.isclass(module_class):
        raise TypeError('module must be a class, '
                        f'but got {
      
      type(module_class)}')
	
	# 如果未指定模块名称则使用默认名称
    if module_name is None:
        module_name = module_class.__name__
    # module_name为list形式, 从而支持在nn.Sequentail中构建pytorch模块
    if isinstance(module_name, str):
        module_name = [module_name]
    for name in module_name:
    	# 如果force=False, 则不允许注册相同名称的模块
    	# 如果force=True, 则用后一次的注册覆盖前一次
        if not force and name in self._module_dict:
            raise KeyError(f'{
      
      name} is already registered '
                           f'in {
      
      self.name}')
        # 将当前注册的模块加入到查询表中
        self._module_dict[name] = module_class

  build函数是指向build_func()函数的(见Registry的构造函数),可以在模块注册的时候由用户手动指定,但由于模块一般都是用函数装饰器的方式来注册,所以build_func()实际上调用的都是build_from_cfg()函数。build_from_cfg()根据配置参数中的type值找到对应的模块类型obj_cls,然后使用cfg和default_args中的参数实例化对应的模块,并返回实例化对象给上级的build()函数调用。

def build_from_cfg(cfg, registry, default_args=None):
    """
    cfg (dict): 配置参数信息
    registry (Registry): 注册器
    """
    # cfg类型校验, 必须为字典类型
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {
      
      type(cfg)}')
    # cfg中必须要有type字段
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {
      
      cfg}\n{
      
      default_args}')
    # registry类型校验, 必须为Registry类型
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {
      
      type(registry)}')
    # default_args以字典的形式传入
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {
      
      type(default_args)}')

    args = cfg.copy()
	
	# 将cfg以外的外部传入参数也加入到args中
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
	# 获取模块名称
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
    	# 根据模块名称获取到模块类型
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{
      
      obj_type} is not in the {
      
      registry.name} registry')
    elif inspect.isclass(obj_type):
    	# type值是模块本身
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {
      
      type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{
      
      obj_cls.__name__}: {
      
      e}')

2. 小结

  MMCV使用注册器来管理具有相似功能的不同模块,一个注册器内部会维护一个查询表,使用该注册器注册的模块都会以键值对的形式保存在这个查询表中,注册器还提供实例化方法,根据模块名称返回对应的实例化对象。

  MMDetection内部已经构建了许多常用的注册器,并实现了对应的接口函数,比如DETECTORS对应build_detector(),DATASETS对应build_dataset(),无论是什么样的xxx_build(),最终都是调用Registry.build()函数。我们绝大多数时候只需要使用现成的注册器即可。

# MMDetection中的Registry
MODELS = Registry('models', parent=MMCV_MODELS)		# 从MMCV继承得到
BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')

猜你喜欢

转载自blog.csdn.net/qq_16137569/article/details/121216363