detectron2源码阅读2---使用configurable装饰器来构建dataloader

前言

 本篇主要讲解detectron2是如何读取数据集并用dataloader进行包装的。一个目标检测模型往往包含众多参数,那么如何提取出对应数据集的参数呢?detectron2设计了configuable装饰器。因此,本文主要分析下读取过程。细节后续有空在写。

1、从train.py文件debug开始

  在介绍detectron2的engine中,默认的训练器是engine/defaults.py文件中的 类class DefaultTrainer(TrainerBase)。在其中初始化类中,有一个构建读取数据集的接口:

data_loader = self.build_train_loader(cfg)

而build_train_loader继续下挖一层:

    @classmethod
    def build_train_loader(cls, cfg):
        return build_detection_train_loader(cfg)

OK,继续挖…

@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):

okay,终于遇到今天的难点了,我们发现,该函数用@configurable装饰器进行了装饰,而且该装饰器后边还跟着一个from_config参数。现在不理解装饰器没关系,我们现在只需要知道被装饰器包装的函数: build_detection_train_loader记为orig_func. 此处易于后续理解,先记住!
装饰器的执行顺序就是先执行装饰的部分,然后在执行被装饰的函数orig_func。所以,继续debug,你会进入到configurable装饰器里。而且,记住configurable的第一个参数是"from_config = _train_loader_from_config"。这里,我可以先告诉你:_train_loader_from_config是一个函数,你现在不需要知道其具体内容,你现在只需要把它看成from_config 即可。

2、函数装饰器configurable

  configurable实现在detectron2/config/config.py文件中。我这里先贴下其部分源码:

def configurable(init_func=None, *, from_config=None):      # * 后面参数必须明示写出来

    if init_func is not None:          # 若指定了init_func则执行if条件语句
        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            if _called_with_cfg(*args, **kwargs):
                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
                init_func(self, **explicit_args)
            else:
                init_func(self, *args, **kwargs)
        return wrapped

    else:                             # 若没有指定,则执行else语句
        def wrapper(orig_func):       # 此时orig_func就是指被装饰的原始函数

            @functools.wraps(orig_func)
            def wrapped(*args, **kwargs):
                if _called_with_cfg(*args, **kwargs):
                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                    return orig_func(**explicit_args)           
                else:
                    return orig_func(*args, **kwargs)
            return wrapped
        return wrapper

  此处结构是if -else, 区别就是指定了init_func参数。第一部分我们只指定了from_config参数,因此,我们只需要看else部分即可。此时你发现了代码结构:
def warpper(orig_func)
_get_args_from_config(from_config)
return wrapper
  其本质就是funA = funB(funA)。简单来说就是:此处funA是orig_func,之后在funB函数wrapped函数内内拓展了funA的一部分功能,比如该函数内部中间开小灶,调用了一个_get_args_from_config(from_config)函数并且返回了一个orig_func(**args),即返回funA。
  所以,到此你就可以猜出来:构建dataloader首先开小灶调用from_config函数构建了一个dataset类,之后在通过orig_func构建dataloader等。

3、合并

3.1 from_config函数

  第二部分你了解了装饰器内容,结合第一部分你记住的:orig_func 是build_detection_train_loade, from_config是_train_loader_from_config。这里看下_train_loader_from_config的代码:

def _train_loader_from_config(cfg, *, mapper=None, dataset=None, sampler=None):
    if dataset is None:
        dataset = get_detection_dataset_dicts(                  # 读取数据集
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = TrainingSampler(len(dataset))
        elif sampler_name == "RepeatFactorTrainingSampler":
            repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                dataset, cfg.DATALOADER.REPEAT_THRESHOLD
            )
            sampler = RepeatFactorTrainingSampler(repeat_factors)
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
    
    
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }

  该函数主要作用就是构建了dataset和sampler。即开小灶程序拓展的功能。

3.2 总的程序流程

 现在走一下调用流程,首先开小灶,程序执行了explicit_args = _get_args_from_config(from_config, *args, **kwargs)。我这里贴下代码:

def _get_args_from_config(from_config_func):      # !!!!!!!!!!!此处第一个参数即from_config
    if support_var_arg:  # forward all arguments to from_config, if from_config accepts them
        ret = from_config_func(*args, **kwargs)
    else:
        # forward supported arguments to from_config
        supported_arg_names = set(signature.parameters.keys())
        extra_kwargs = {
    
    }
        for name in list(kwargs.keys()):
            if name not in supported_arg_names:
                extra_kwargs[name] = kwargs.pop(name)
        ret = from_config_func(*args, **kwargs)            # !!!!!!!!!!!!!!!!!!!
        # forward the other arguments to __init__
        ret.update(extra_kwargs)
    return ret

  主要看代码中我加感叹号的部分,实质上该函数就是调用了3.1节中的函数,即开小灶去构建了一个dataset和sampler;之后在调用orig_func完成dataloader构建:

@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):

    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )

总结

  detectron2中用到configurable装饰器的地方不少。就不一一列举了,后续会介绍如何封装dataset,即如何开小灶的。

Guess you like

Origin blog.csdn.net/wulele2/article/details/119081975