Pytorch注册器机制Registry

Pytorch注册器机制Registry

在众多深度学习开源库的代码中经常出现Registry代码块,例如OpenMMlabfacebookresearchBasicSR中都使用了注册器机制。这块的代码经常会让新使用这些库的初学者感到一头雾水,本篇博客来分析一下注册器机制的原理与好处。


1. 为什么使用registry

在讲解registry原理前,我们先介绍一下,为何使用registry。registry的中文翻译是注册器。对于一个好用的深度学习代码库来说,通常都会内置多种损失函数,多种网络结构,以及多种优化器等。同时这类的库一般都支持从配置文件中,直接解析出模型结构与训练策略。那么如何优雅的从配置文件解析到具体的代码实现呢?这就是引入注册操作的意义,简而言之,注册器是为了方便找到相关模块。

2. registry代码阅读

在实现上不同代码库略有差异,但原理相同,所以这里就以BasicSR为例。

class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.
    To create a registry (e.g. a backbone registry):
    .. code-block:: python
        BACKBONE_REGISTRY = Registry('BACKBONE')
    To register an object:
    .. code-block:: python
        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...
    Or:
    .. code-block:: python
        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {
    
    }

    def _do_register(self, name, obj, suffix=None):
        if isinstance(suffix, str):
            name = name + '_' + suffix

        assert (name not in self._obj_map), (f"An object named '{
      
      name}' was already registered "
                                             f"in '{
      
      self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None, suffix=None):
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class, suffix)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj, suffix)

    def get(self, name, suffix='basicsr'):
        ret = self._obj_map.get(name)
        if ret is None:
            ret = self._obj_map.get(name + '_' + suffix)
            print(f'Name {
      
      name} is not found, use name: {
      
      name}_{
      
      suffix}!')
        if ret is None:
            raise KeyError(f"No object named '{
      
      name}' found in '{
      
      self._name}' registry!")
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')

上面的代码为数据集,架构,网络,损失以及度量方式都创建了一个注册器对象。核心代码在register函数里,register函数使用了装饰器的设计,也就是只要在功能模块前进行@xx.register()进行装饰,就会对原有功能模块进行注册,并且最终返回原始的功能模块,不修改其原有功能。

在更下层的_do_register()中可以看到,这里使用的是一个字典来执行注册操作,记录的键值对分别是模块的名称以及模块本身。这样一来,读取配置文件中的模块字符串后,我们就能够直接通过函数名或者类名找到其具体实现。

使用方法如下所示,只需要在此类前加上装饰,后期则直接能够从字符串L1Loss找到其对应的实现。

@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.
    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {
      
      reduction}. Supported ones are: {
      
      _reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)

猜你喜欢

转载自blog.csdn.net/weiman1/article/details/125610831
今日推荐