MMDetection Framework Introductory Tutorial (4): The Registry Mechanism in Detail

  The previous post introduced configuration files in MMDetection. Once the model, dataset, training strategy, and so on are set in a configuration file, the Config class manages those parameters as a dictionary, and the MMDetection framework parses them automatically to build the whole algorithm pipeline. MMDetection relies on the registry mechanism to get from configuration parameters to algorithm modules. This post starts from the source code and walks through the registry mechanism in MMCV in detail.

  1. Official Documentation - MMCV
  2. Official Zhihu column - MMCV core component analysis (5): Registry

1. Registry

  The registry mechanism is a very important concept in MMCV. To add your own algorithm modules or procedures to MMDetection, you need to go through the registry mechanism.

1.1 Registry class

  Before looking at the registration mechanism itself, let's introduce the Registry class.

  MMCV uses registries to manage different modules with similar functions. For example, ResNet, FPN, and RoIHead are all model structures, while SGD and Adam are both optimizers. A registry maintains a global lookup table whose keys are strings and whose values are classes.

  In simple terms, a registry can be seen as a mapping from strings to classes. With it, users can look up a class by its string name and instantiate it. With this in mind, the source code of the Registry class is easy to follow. Let's look at the constructor first. It mainly initializes the registry's name and its build function, and creates a dictionary-type lookup table, _module_dict:

from mmcv.utils import Registry

class Registry:
    # Constructor
    def __init__(self, name, build_func=None, parent=None, scope=None):
        """
        name (str): name of the registry
        build_func (func): function used to build an instance from the registry
        parent (Registry): parent registry
        scope (str): scope of the registry
        """
        self._name = name
        # _module_dict manages the mapping from strings to classes
        self._module_dict = dict()
        self._children = dict()
        # If scope is not specified, default to the name of the package
        # where the registry is defined, e.g. mmdet, mmseg
        self._scope = self.infer_scope() if scope is None else scope

        # build_func is initialized with the following priority:
        # 1. build_func: use the explicitly specified function first
        # 2. parent.build_func: otherwise use the parent's build_func
        # 3. build_from_cfg: by default, instantiate from a config dict
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func

        # Set up the parent-child relationship
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None
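
  The build_func priority in the constructor can be tried out without installing MMCV. The following is only a minimal sketch under stated assumptions: TinyRegistry, default_build, and logging_build are hypothetical names made up for illustration, and the class keeps just enough of the constructor logic to show which build function ends up being used.

```python
def default_build(cfg, registry):
    """Stand-in for build_from_cfg: look up cfg['type'] and instantiate it."""
    args = dict(cfg)
    cls = registry.module_dict[args.pop('type')]
    return cls(**args)


class TinyRegistry:
    """Hypothetical, simplified registry; not MMCV's implementation."""

    def __init__(self, name, build_func=None, parent=None):
        self.name = name
        self.module_dict = {}
        # Priority: explicit build_func > parent's build_func > default.
        if build_func is not None:
            self.build_func = build_func
        elif parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = default_build

    def build(self, cfg):
        return self.build_func(cfg, registry=self)


def logging_build(cfg, registry):
    """A custom build function that tags every object it creates."""
    obj = default_build(cfg, registry)
    obj.built_by = 'logging_build'
    return obj


class Conv:
    def __init__(self, channels):
        self.channels = channels


root = TinyRegistry('root', build_func=logging_build)
child = TinyRegistry('child', parent=root)   # inherits logging_build
plain = TinyRegistry('plain')                # falls back to default_build

child.module_dict['Conv'] = Conv
plain.module_dict['Conv'] = Conv

a = child.build(dict(type='Conv', channels=8))
b = plain.build(dict(type='Conv', channels=16))
print(a.channels, a.built_by)                # 8 logging_build
print(b.channels, hasattr(b, 'built_by'))    # 16 False
```

  The child registry never receives a build function of its own, yet its objects carry the tag set by the parent's function, which is the second priority level in action.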

  For example, suppose we want to use a registry to manage our models. First, create a Registry instance MODELS, then call the Registry class's register_module() method to register the ResNet and FPN classes. Printing MODELS shows that it contains both classes (the items field in the printed output is actually self._module_dict), which means the registration succeeded. For brevity, it is recommended to call register_module() through Python's decorator syntax (@). We can then instantiate our models with the build() function.

# Create a registry to manage models
MODELS = Registry('myModels')

# Method 1: register with a decorator while the class is being defined (recommended)
@MODELS.register_module()
class ResNet(object):
    def __init__(self, depth):
        self.depth = depth
        print('Initialize ResNet{}'.format(depth))

# Method 2: define the class first, then call register_module explicitly (not recommended)
class FPN(object):
    def __init__(self, in_channel):
        self.in_channel = in_channel
        print('Initialize FPN{}'.format(in_channel))
MODELS.register_module(name='FPN', module=FPN)

print(MODELS)
""" Output:
Registry(name=myModels, items={'ResNet': <class '__main__.ResNet'>, 'FPN': <class '__main__.FPN'>})
"""

# Configuration parameters; cfg usually comes from a configuration file
backbone_cfg = dict(type='ResNet', depth=101)
neck_cfg = dict(type='FPN', in_channel=256)
# Instantiate the models (the config parameters are passed to each constructor)
my_backbone = MODELS.build(backbone_cfg)
my_neck = MODELS.build(neck_cfg)
print(my_backbone, my_neck)
""" Output:
Initialize ResNet101
Initialize FPN256
<__main__.ResNet object at 0x000001E68E99E198> <__main__.FPN object at 0x000001E695044B38>
"""

1.2 The register_module and build functions

  After a Registry object is instantiated, classes are registered and instantiated through the register_module and build functions, respectively. Let's look at the source code of these two functions.

  Internally, register_module() calls self._register_module(). This function is very simple: it stores the name and class of the module being registered in the _module_dict lookup table as a key-value pair.

def _register_module(self, module_class, module_name=None, force=False):
    """
    module_class (class): the class of the module to register
    module_name (str): the name to register the module under
    force (bool): whether to force registration
    """
    if not inspect.isclass(module_class):
        raise TypeError('module must be a class, '
                        f'but got {type(module_class)}')

    # Fall back to the class name if no module name is given
    if module_name is None:
        module_name = module_class.__name__
    # module_name is handled as a list, so one class can be
    # registered under several names
    if isinstance(module_name, str):
        module_name = [module_name]
    for name in module_name:
        # If force=False, registering the same name twice is not allowed
        # If force=True, the later registration overwrites the earlier one
        if not force and name in self._module_dict:
            raise KeyError(f'{name} is already registered '
                           f'in {self.name}')
        # Add the module being registered to the lookup table
        self._module_dict[name] = module_class
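
  To see how the decorator form, the direct call, and the force flag fit together, here is a small self-contained sketch. It assumes a hypothetical TinyRegistry class that copies the registration logic above into a standalone toy. It is not MMCV's code, only an approximation that runs without MMCV installed.

```python
import inspect


class TinyRegistry:
    """Hypothetical, simplified registry; not MMCV's implementation."""

    def __init__(self, name):
        self.name = name
        self.module_dict = {}

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')
        names = module_class.__name__ if module_name is None else module_name
        if isinstance(names, str):
            names = [names]
        for name in names:
            if not force and name in self.module_dict:
                raise KeyError(f'{name} is already registered in {self.name}')
            self.module_dict[name] = module_class

    def register_module(self, name=None, force=False, module=None):
        # Direct call: register the given class right away.
        if module is not None:
            self._register_module(module, module_name=name, force=force)
            return module

        # Decorator use: return a function that receives the class being
        # defined, registers it, and hands it back unchanged.
        def _register(cls):
            self._register_module(cls, module_name=name, force=force)
            return cls

        return _register


MODELS = TinyRegistry('myModels')


@MODELS.register_module()
class ResNet:
    pass


class ResNetV2:
    pass


# Registering the same name again fails unless force=True.
try:
    MODELS.register_module(name='ResNet', module=ResNetV2)
    duplicate_rejected = False
except KeyError:
    duplicate_rejected = True
print(duplicate_rejected)                        # True

MODELS.register_module(name='ResNet', force=True, module=ResNetV2)
print(MODELS.module_dict['ResNet'] is ResNetV2)  # True

# A list of names registers one class under several keys.
MODELS.register_module(name=['R50', 'ResNet50'], module=ResNet)
print(sorted(MODELS.module_dict))                # ['R50', 'ResNet', 'ResNet50']
```

  Because the decorator returns the class unchanged, ResNet can still be used directly after registration. The registry only records a reference to it.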

  The build() function delegates to build_func() (see the Registry constructor), which the user can specify when creating the registry. In practice it is usually left unspecified, so build_func() normally ends up being build_from_cfg(). build_from_cfg() looks up the module class obj_cls from the type value in the configuration, instantiates it with the parameters in cfg and default_args, and returns the instantiated object to the calling build() function.

def build_from_cfg(cfg, registry, default_args=None):
    """
    cfg (dict): configuration parameters
    registry (Registry): the registry to look the module up in
    default_args (dict, optional): default initialization arguments
    """
    # Type check: cfg must be a dict
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    # cfg (or default_args) must contain the 'type' key
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    # Type check: registry must be a Registry
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    # default_args must be passed as a dict (or None)
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    # Add externally supplied arguments that are not in cfg to args
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    # Get the module name
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        # Look up the module class by its name
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        # The type value is the class itself
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
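
  Two details of build_from_cfg() are easy to miss: values in cfg take precedence over default_args because of setdefault(), and constructor errors are re-raised with the class name prepended. The snippet below isolates both steps in plain Python. merge_args is a hypothetical helper name introduced only for this illustration, and the misspelled keyword deepth is deliberate.

```python
def merge_args(cfg, default_args=None):
    """Copy cfg, then fill in only the keys that cfg does not already have."""
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return args


cfg = dict(type='ResNet', depth=50)
merged = merge_args(cfg, default_args=dict(depth=101, num_stages=4))
print(merged)   # {'type': 'ResNet', 'depth': 50, 'num_stages': 4}
print(cfg)      # {'type': 'ResNet', 'depth': 50}  (the original is untouched)


class ResNet:
    def __init__(self, depth):
        self.depth = depth


# A misspelled keyword raises TypeError inside the constructor; wrapping it
# the same way as build_from_cfg prefixes the message with the class name.
try:
    ResNet(**{'deepth': 50})
    wrapped = None
except Exception as e:
    wrapped = type(e)(f'{ResNet.__name__}: {e}')

print(type(wrapped).__name__)   # TypeError
print(wrapped)
```

  The prefix matters in practice because a config usually builds many modules in one pass, and the bare TypeError message alone may not say which class rejected the argument.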

2. Summary

  MMCV uses registries to manage different modules with similar functions. Each registry maintains an internal lookup table, and modules registered with it are stored in that table as key-value pairs. A registry also provides an instantiation method that returns an instance of the corresponding class given the module name.

  MMDetection already defines many commonly used registries and implements the corresponding interface functions, such as build_detector() for DETECTORS and build_dataset() for DATASETS. Whatever the build_xxx() function, it ultimately calls Registry.build(). Most of the time we only need the ready-made registries.

# Registries in MMDetection
MODELS = Registry('models', parent=MMCV_MODELS)  # inherits from MMCV's registry
BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
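
  In the listing above, BACKBONES, NECKS, DETECTORS, and the rest are plain aliases of MODELS, so a class registered through one name is visible through all of them. The sketch below illustrates this with a hypothetical TinyRegistry stand-in rather than MMCV's Registry, and the build_backbone wrapper is written in the style of the build_xxx() helpers as an assumption about their general shape.

```python
class TinyRegistry:
    """Hypothetical stand-in for a registry; not MMCV's Registry class."""

    def __init__(self, name):
        self.name = name
        self.module_dict = {}

    def register_module(self):
        def _register(cls):
            self.module_dict[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        args = dict(cfg)
        return self.module_dict[args.pop('type')](**args)


MODELS = TinyRegistry('models')
BACKBONES = MODELS   # aliases: the same object under another name
DETECTORS = MODELS


def build_backbone(cfg):
    """Thin wrapper that simply forwards to the registry's build()."""
    return BACKBONES.build(cfg)


@BACKBONES.register_module()
class ResNet:
    def __init__(self, depth):
        self.depth = depth


print(BACKBONES is DETECTORS)              # True
print('ResNet' in DETECTORS.module_dict)   # True
backbone = build_backbone(dict(type='ResNet', depth=50))
print(backbone.depth)                      # 50
```

  One consequence of sharing a single table is that a backbone and a detector cannot be registered under the same name, since both end up as keys of the same dictionary.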

Origin blog.csdn.net/qq_16137569/article/details/121216363