Registry注册机制

前言:不管是Detectron还是mmdetection,都有用到这个register机制,特意去弄明白,记录一下。

首先看Registry代码:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, Optional, Iterable, Tuple, Iterator

from tabulate import tabulate


class Registry(Iterable[Tuple[str, object]]):
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)

    Iterating over a registry yields ``(name, object)`` pairs, and
    ``name in registry`` tests whether a name has been registered.
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry
        """
        self._name: str = name
        self._obj_map: Dict[str, object] = {}

    def _do_register(self, name: str, obj: object) -> None:
        """
        Store `obj` under `name`.

        Raises:
            AssertionError: if `name` is already present in this registry.
        """
        # Raised explicitly instead of via an `assert` statement so that the
        # duplicate check is not stripped when Python runs with -O; the
        # exception type and message are unchanged.
        if name in self._obj_map:
            raise AssertionError(
                "An object named '{}' was already registered in '{}' registry!".format(
                    name, self._name
                )
            )
        self._obj_map[name] = obj

    def register(self, obj: object = None) -> Optional[object]:
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not. See docstring of this class for usage.

        Args:
            obj: the function or class to register. When omitted, a decorator
                is returned that registers whatever it is applied to.

        Returns:
            The decorator when `obj` is None; otherwise `obj` itself, so that
            the bare form ``@registry.register`` (without parentheses) leaves
            the decorated name bound to the original object rather than None.

        Raises:
            AssertionError: if the name is already registered.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class: object) -> object:
                name = func_or_class.__name__  # pyre-ignore
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call (or as a bare decorator)
        name = obj.__name__  # pyre-ignore
        self._do_register(name, obj)
        return obj

    def get(self, name: str) -> object:
        """
        Look up a registered object by name.

        Raises:
            KeyError: if nothing is registered under `name`.
        """
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(name, self._name)
            )
        return ret

    def __contains__(self, name: str) -> bool:
        return name in self._obj_map

    def __repr__(self) -> str:
        # Render the mapping as a two-column table via `tabulate`.
        table_headers = ["Names", "Objects"]
        table = tabulate(
            self._obj_map.items(), headers=table_headers, tablefmt="fancy_grid"
        )
        return "Registry of {}:\n".format(self._name) + table

    def __iter__(self) -> Iterator[Tuple[str, object]]:
        return iter(self._obj_map.items())

    # pyre-fixme[4]: Attribute must be annotated.
    __str__ = __repr__

可以看出,register方法通过调用_do_register,把函数名或类名作为键、对应的函数对象或类对象作为值存入一个字典,再通过get方法按名称取回该函数或类。

示例代码调用:


from fvcore.common.registry import Registry

# Module-level registry named "BACKBONE"; the demo below registers into it
# and looks entries up by name.
BACKBONE_REGISTRY = Registry("BACKBONE")

@BACKBONE_REGISTRY.register()
def test_register(cfg):
    """Demo callable stored in BACKBONE_REGISTRY under the name 'test_register'.

    Prints the given `cfg` and returns a fixed marker string.
    """
    marker = '==test_register is called'
    print('==cfg:', cfg)
    return marker


def debug_register():
    """Look up 'test_register' in BACKBONE_REGISTRY, print it, then call it."""
    lookup = BACKBONE_REGISTRY.get
    # The lookup returns the registered function or class object itself.
    print(lookup('test_register'))
    cfg = 'hahahah'
    # Invoke the registered function (or class) with the demo config.
    outcome = lookup('test_register')(cfg)
    print('==res:', outcome)

# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    debug_register()

猜你喜欢

转载自blog.csdn.net/fanzonghao/article/details/110638006
今日推荐