Registro del mecanismo de registro de Pytorch

Registro del mecanismo de registro de Pytorch

Los bloques de código de registro a menudo aparecen en el código de muchas bibliotecas de código abierto de aprendizaje profundo.Por ejemplo, OpenMMlab , facebookresearch y BasicSR utilizan el mecanismo de registro. Este fragmento de código a menudo confunde a los principiantes que son nuevos en estas bibliotecas. Este blog analizará los principios y beneficios del mecanismo de registro.


1. Por qué usar el registro

Antes de explicar el principio del registro, introduzcamos por qué se usa el registro. La traducción china de registro es el registro. Para una biblioteca de código de aprendizaje profundo útil, generalmente hay múltiples funciones de pérdida integradas, múltiples estructuras de red y múltiples optimizadores. Al mismo tiempo, dichas bibliotecas generalmente admiten el análisis directo de la estructura del modelo y la estrategia de entrenamiento desde el archivo de configuración. Entonces, ¿cómo analizar con elegancia desde el archivo de configuración hasta la implementación del código específico? Este es el significado de introducir la operación de registro. En resumen, el registro es para facilitar la búsqueda de módulos relacionados.

2. Lectura del código de registro

Existen ligeras diferencias en la implementación de diferentes bases de código, pero el principio es el mismo, así que aquí tomamos BasicSR como ejemplo.

class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.
    To create a registry (e.g. a backbone registry):
    .. code-block:: python
        BACKBONE_REGISTRY = Registry('BACKBONE')
    To register an object:
    .. code-block:: python
        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...
    Or:
    .. code-block:: python
        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {
    
    }

    def _do_register(self, name, obj, suffix=None):
        if isinstance(suffix, str):
            name = name + '_' + suffix

        assert (name not in self._obj_map), (f"An object named '{
      
      name}' was already registered "
                                             f"in '{
      
      self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None, suffix=None):
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class, suffix)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj, suffix)

    def get(self, name, suffix='basicsr'):
        ret = self._obj_map.get(name)
        if ret is None:
            ret = self._obj_map.get(name + '_' + suffix)
            print(f'Name {
      
      name} is not found, use name: {
      
      name}_{
      
      suffix}!')
        if ret is None:
            raise KeyError(f"No object named '{
      
      name}' found in '{
      
      self._name}' registry!")
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')

El código anterior crea un objeto de registro para el conjunto de datos, la arquitectura, la red, la pérdida y la métrica. El código central está en la función de registro. La función de registro utiliza el diseño del decorador, es decir, siempre que @xx.register() esté decorado antes que el módulo de función, el módulo de función original se registrará y el módulo de función original se regresó eventualmente. , sin modificar su función original.

Como puede ver en el nivel inferior _do_register(), aquí se utiliza un diccionario para realizar la operación de registro, y los pares clave-valor registrados son el nombre del módulo y el propio módulo. De esta forma, después de leer la cadena del módulo en el archivo de configuración, podemos encontrar su implementación específica directamente a través del nombre de la función o el nombre de la clase.

El método de uso es el siguiente, solo necesita agregar decoraciones antes de esta clase, y puede encontrar directamente su implementación correspondiente desde la cadena L1Loss más adelante.

@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.
    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {
      
      reduction}. Supported ones are: {
      
      _reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)

Supongo que te gusta

Origin blog.csdn.net/weiman1/article/details/125610831
Recomendado
Clasificación