Notes d'étude MMDetection (4): Analyse des composants de base

Analyse des composants de base

Ce blog se concentre sur l'analyse des trois composants principaux de MMDetection : Registry, Hook et Runner.

Enregistrement

Le mécanisme de registre maintient en fait un dictionnaire global qui mappe les chaînes aux classes. Grâce à la classe Registry, les utilisateurs peuvent instancier n'importe quelle classe (ou module) souhaitée via la chaîne dans config. Les avantages de Registry sont : un fort découplage, une forte évolutivité et un code plus facile à comprendre.

Le code source d'implémentation de la classe Registry dans MMCV :

class Registry:
    def __init__(self, name):
        # 可实现注册类细分功能
        self._name = name 
        # 内部核心内容,维护所有的已经注册好的 class
        self._module_dict = dict()

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {
      
      type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{
      
      module_name} is already registered '
                           f'in {
      
      self.name}')
        # 最核心代码
        self._module_dict[module_name] = module_class

    # 装饰器函数
    def register_module(self, name=None, force=False, module=None):
        if module is not None:
            # 如果已经是 module,那就知道 增加到字典中即可
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # 最标准用法
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls
        return _register

Toutes les instanciations de classe dans MMCV sont build_from_cfgimplémentées via des fonctions, et ce qu'elles font est très simple, c'est-à-dire données module_namepuis self._module_dict extraites.

def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type') # 注册 str 类名
    if is_str(obj_type):
        # 相当于 self._module_dict[obj_type]
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{
      
      obj_type} is not in the {
      
      registry.name} registry')

    # 如果已经实例化了,那就直接返回
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {
      
      type(obj_type)}')

    # 最终初始化对于类,并且返回,就完成了一个类的实例化过程
    return obj_cls(**args)

Un exemple d'utilisation complet est le suivant :

# registry
CONVERTERS = Registry('converter')

@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

# config
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = build_from_cfg(converter_cfg,CONVERTERS)

Accrocher

Définition du crochet

Il est défini dans wikipedia comme suit :

L'accrochage, également appelé "accrochage", est un terme de programmation informatique qui fait référence à la modification ou à l'extension du comportement des systèmes d'exploitation, des applications ou d'autres composants logiciels en interceptant les appels de fonction, la transmission de messages et la transmission d'événements entre des modules logiciels et diverses techniques. Le code qui gère les appels de fonction, les événements et les messages interceptés est appelé un crochet

En termes simples, le mécanisme Hook peut étendre les fonctions de manière non invasive tout au long du cycle de vie de l'exécution du code. Le mécanisme Hook est largement utilisé dans la série de frameworks OpenMMLab et, combiné à la classe Runner, il peut gérer l'ensemble du cycle de vie du processus de formation. Dans le même temps, une variété de crochets sont intégrés, qui sont injectés dans le Runner sous forme d'enregistrement pour réaliser des fonctions d'extension riches, telles que l'économie de poids du modèle, la journalisation, l'ajustement de l'hyperparamètre lr, etc.

Mécanisme d'appel de crochet

Dans MMDetection, les Hooks peuvent être enregistrés dans le Runner. Différents types de Hooks implémentent différentes méthodes de cycle de vie pour remplir différentes fonctions. En prenant un processus de formation typique comme exemple, les méthodes de cycle de vie dans EpochBasedRunner (à l'époque) sont les suivantes.

# 开始运行时调用
before_run()

while self.epoch < self._max_epochs:

    # 开始 epoch 迭代前调用
    before_train_epoch()

    for i, data_batch in enumerate(self.data_loader):
        # 开始一次(iteration)迭代前调用
        before_train_iter()

        self.model.train_step()

        # 经过一次(iteration)迭代后调用
        after_train_iter()

    # 经过一个 epoch 迭代后调用
    after_train_epoch()

# 运行完成前调用
after_run()

Tant que l'objet Hook enregistré implémente une ou plusieurs méthodes de cycle de vie, lorsque le Runner s'exécutera vers une position prédéfinie, il appellera la méthode Hook correspondante.

Classification et utilisation de Hook

Les crochets implémentés dans MMCV incluent des crochets par défaut et des crochets personnalisés. Les crochets par défaut ne nécessitent pas que les utilisateurs s'enregistrent eux-mêmes, et les utilisateurs peuvent configurer les paramètres correspondants via (nom du crochet)_config ; pour les crochets personnalisés, les utilisateurs doivent s'enregistrer manuellement ou via la configuration.

Pour le Hook par défaut, lors du processus d'apprentissage du framework MMDetection, son code d'enregistrement est :

runner.register_training_hooks(cfg.lr_config, optimizer_config,
                               cfg.checkpoint_config, cfg.log_config,
                               cfg.get('momentum_config', None))

register_training_hooksLes paramètres de réception de la fonction sont en fait des paramètres de dictionnaire. Le Runner générera automatiquement l'instance de Hook correspondante en fonction de la configuration. Le lr_config typique est :

lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[16, 22])

Pour les Hooks personnalisés, le code source d'enregistrement est le suivant :

# user-defined hooks
if cfg.get('custom_hooks', None):
    custom_hooks = cfg.custom_hooks
    for hook_cfg in cfg.custom_hooks:
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        # 通过配置实例化定制 hook
        hook = build_from_cfg(hook_cfg, HOOKS)
        # 注册
        runner.register_hook(hook, priority=priority)

En prenant EMAHook comme exemple, son fichier de configuration .py doit être écrit comme suit :

custom_hooks=[dict(type='EMAHook')]

Ce qui suit est une brève analyse fonctionnelle de certains crochets courants et couramment utilisés :

  1. CheckpointHook
    CheckpointHook sert principalement à enregistrer les paramètres du modèle. S'il s'agit d'une formation multi-cartes distribuée, il ne sera enregistré que dans le processus maître. Dans le même temps, vous pouvez max_keep_ckptsdéfinir le nombre maximum de fichiers de poids à enregistrer via les paramètres, et les poids supplémentaires au stade initial seront automatiquement supprimés.
    Si l'enregistrement est effectué en unités d'époque, alors le Hook after_train_epochpeut implémenter la méthode, sinon il after_train_iter n'a qu'à implémenter la méthode.
  2. LrUpdaterHook
    LrUpdaterHook est utilisé pour la planification du taux d'apprentissage Afin d'unifier le style de code et de faciliter l'expansion, les frameworks de formation tels que MMDetection n'héritent pas directement du planificateur de taux d'apprentissage fourni par PyTorch, mais l'implémentent via LrUpdaterHook.
    Si l'unité est iter, il vous suffit d' before_train_iter implémenter la fonction de planification du taux d'apprentissage dans la méthode, et si l'unité est epoch, vous devez également implémenter before_train_epochles opérations associées dans . En termes simples, le taux d'apprentissage doit être modifié en temps réel.
  3. OptimizerHook
    La fonction d'OptimizerHook est relativement simple : rétropropagation du gradient plus mise à jour des paramètres, si le paramètre d'écrêtage du dégradé est spécifié, l'écrêtage du dégradé peut être effectué.
  4. ClosureHook
    ClosureHook est spécial, sa fonction principale est de fournir l'enregistrement de fonction le plus concis. Vous pouvez imaginer un scénario : pendant le processus de formation, si vous souhaitez connaître le nombre actuel d'itérations, la méthode d'implémentation la plus élégante sous le système de framework actuel est : l'utilisateur écrit une classe Hook pour obtenir iter, puis l'enregistre
    dans le fichier de configuration. custom_hooksLe code ressemble à ceci :
    @HOOKS.register_module()
    class GetIterHook(Hook):
        def after_train_iter(self, runner):
            print(runner.iter)
    
    Il peut être trouvé que vous devez faire ce qui suit :
    (1) Écrivez une classe GetIterHook qui hérite de Hook ;
    (2) Ajoutez @HOOKS.register_module() au-dessus de la classe ; (3) Importez-le dans le fichier init.py
    correspondant ; (4) Enregistrez le crochet sur le coureur. Trois étapes sont à franchir, mais en fait je veux juste imprimer, ce qui est fastidieux, et le rôle de ClosureHook est de simplifier le processus. Ce que vous devez faire maintenant est le suivant :

    def getiter(runner):
        print(runner.iter)
    
    (1) Définissez la fonction ci-dessus ;
    (2) Saisissez-la dans ClosureHook en tant que paramètre et instanciez ClosureHook('after_train_iter', getiter) ;
    (3) Enregistrez le Hook dans le Runner.
    ClosureHook est principalement utilisé pour certains crochets très simples, mais vous ne voulez pas redéfinir une classe pour y parvenir. À ce stade, vous pouvez définir une fonction et la transmettre à ClosureHook.

Coureur

Runner est responsable de la planification des processus de tous les pipelines du framework dans OpenMMLab et fournit des modes d'itération basés sur Epoch et Iter pour répondre à différents scénarios. Par exemple, MMDetection utilise Epoch par défaut (les paramètres pertinents dans le fichier de configuration sont tous en unités Epoch). , tandis que MMSegmentation utilise par défaut Epoch.Iter (les paramètres pertinents dans le fichier de configuration sont tous en unité Iter). Coopérez avec divers crochets pour réaliser une extension fonctionnelle de manière élégante.

Le processus d'utilisation de Runner peut être divisé en 4 étapes :

  1. Initialisation de l'objet Runner ;
  2. Enregistrez divers crochets pour le coureur ;
  3. Appelez la méthode resume ou load_checkpoint de Runner pour charger le poids ;
  4. Exécutez le workflow de pipeline donné.

Initialisation du coureur

Considérant que les modes Epoch et Iter ont beaucoup de logique commune, un BaseRunner est abstrait pour être réutilisé. L'initialisation de BaseRunner est un processus d'initialisation standard avec les paramètres suivants :

def __init__(self,
             model,
             batch_processor=None, # 已废弃
             optimizer=None,
             work_dir=None,
             logger=None,
             meta=None, # 提供了该参数,则会保存到 ckpt 中
             max_iters=None, # 这两个参数非常关键,如果没有给定,则内部自己计算
             max_epochs=None):

Enregistrez le crochet

register_training_hooks, enregistrez le crochet par défaut :

def register_training_hooks(self,
                            lr_config, # lr相关
                            optimizer_config=None, # 优化器相关
                            checkpoint_config=None, # ckpt 保存相关
                            log_config=None, # 日志记录相关
                            momentum_config=None, # momentum 相关
                            timer_config=dict(type='IterTimerHook')) # 迭代时间统计

register_hook, tous les autres crochets autres que ceux ci-dessus sont enregistrés via cette méthode, tels que eval_hook, custom_hooks et DistSamplerSeedHook, etc. :

def register_hook(self, hook, priority='NORMAL'):
    # 获取优先级
    priority = get_priority(priority)
    hook.priority = priority
    # 基于优先级计算当前 hook 插入位置
    inserted = False
    for i in range(len(self._hooks) - 1, -1, -1):
        if priority >= self._hooks[i].priority:
            self._hooks.insert(i + 1, hook)
            inserted = True
            break
    if not inserted:
        self._hooks.insert(0, hook)

reprendre ou load_checkpoint

La méthode de reprise est utilisée pour charger des poids lorsque le processus d'entraînement est arrêté puis repris, tandis que load_checkpoint ne sert qu'à charger des poids pré-entraînés. Les poids pré-entraînés peuvent provenir de l'officiel ou de leurs propres poids entraînés. S'il existe des paramètres qui ne correspond pas à la clé, elle sautera automatiquement.

courir

Une fois la méthode d'exécution appelée, le flux de travail est réellement démarré et, comme les processus des modes Epoch et Iter sont différents, ils sont implémentés dans leurs sous-classes respectives.

(1) Exécution d'EpochBasedRunner

def run(self, 
    data_loaders, # dataloader 列表
    workflow,  # 工作流列表,长度需要和 data_loaders 一致
    max_epochs=None, 
    **kwargs):
  • En supposant que vous souhaitez uniquement exécuter le workflow de formation, vous pouvez définir workflow = [('train', 1)], indiquant que les données dans data_loader sont formées de manière itérative
  • Supposons que vous souhaitiez exécuter des flux de travail d'entraînement et de vérification, vous pouvez définir le flux de travail = [('train', 3), ('val',1)], ce qui signifie d'abord s'entraîner pendant 3 époques, puis passer au flux de travail val et exécuter pendant 1 époque, puis boucle jusqu'à ce que le nombre d'époques d'entraînement atteigne la valeur spécifiée
  • Le paramètre de workflow est très libre, par exemple, vous pouvez d'abord vérifier puis former workflow = [('val', 1), ('train',1)]

Il convient de noter que s'il y a deux workflows, deux chargeurs de données doivent également être fournis dans data_loaders. Sa logique de base est la suivante :

def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    # epoch 模式,需要自动计算出 _max_iters
    for i, flow in enumerate(workflow):
        mode, epochs = flow
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    # 调用注册到 runner 中的所有 hook 的 before_run 方法,表示开启 run 前
    self.call_hook('before_run')

    # 如果没有达到退出条件,就一直运行工作流
    while self.epoch < self._max_epochs:
        # 遍历工作流
        for i, flow in enumerate(workflow):
            # 模式,和当前工作流需要运行的 epoch 次数
            mode, epochs = flow
            epoch_runner = getattr(self, mode)
            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                # 开始一个 epoch 的迭代
                epoch_runner(data_loaders[i], **kwargs)
    time.sleep(1)  # wait for some hooks like loggers to finish

    # 调用注册到 runner 中的所有 hook 的 after_run 方法,表示结束 run 后
    self.call_hook('after_run')

La méthode d'exécution définit le processus général de changement de flux de travail, et l'achèvement réel d'un flux de travail d'époque consiste à appeler la fonction de flux de travail. Actuellement, deux workflows, train et val, sont pris en charge, donc epoch_runner(data_loaders[i], **kwargs) l'appel est en fait la méthode train ou val :

# train 和 val 方法逻辑非常相似
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True)
        self.call_hook('after_train_iter')
        self._iter += 1
    self.call_hook('after_train_epoch')
    self._epoch += 1

@torch.no_grad()
def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        self.run_iter(data_batch, train_mode=False)
        self.call_hook('after_val_iter')
    self.call_hook('after_val_epoch')

La logique ci-dessus consiste à traverser le data_loader, puis à effectuer une formation ou une vérification itérative au niveau du lot, ce qui est relativement facile à comprendre. Pour effectuer réellement la formation ou la vérification d'un lot, il faut appeler self.run_iter :

# 简化逻辑
def run_iter(self, data_batch, train_mode, **kwargs):
    # 调用 model 自身的 train_step 或者 val_step 方法
    if train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer,
                                        **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)

    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])

    self.outputs = outputs

(2) Exécution d'IterBasedRunner

Le mode IterBasedRunner utilise le nombre d'itérations comme condition de fin de boucle, sans le concept d'époque, donc la méthode d'exécution d'IterBasedRunner a été légèrement modifiée :

  • La condition de fin de workflow n'est plus epoch, mais iter
  • Les méthodes de cycle de vie de Hook n'impliquent pas d'époque, toutes sont des méthodes liées à iter

Étant donné que MMDetection utilise EpochBasedRunner au lieu d'IterBasedRunner, sa logique de code détaillée n'est plus étendue.

(3) Comparaison entre EpochBasedRunner et IterBasedRunner

En supposant que la longueur des données est de 1024, lot = 4, la longueur du chargeur de données est de 1024/4 = 256, c'est-à-dire qu'une époque correspond à 256 itérations. En mode de formation Iter, il est prévu de former 100 000 itérations. S'il est en le mode d'entraînement Epoch, puis les 100000//256=39 époques réelles ont été exécutées sur .

Guess you like

Origin blog.csdn.net/weixin_43603658/article/details/129113196