Reading the detectron2 Source, Part 3: Wrapping the Dataset with a Mapper

Preface

detectron2 has specific requirements on the format of the datasets it reads: VOC- and COCO-style datasets, for example, have to be converted into the dataset_dict form. Stated that baldly this probably doesn't mean much yet, so let's walk through it step by step.

1. The DatasetCatalog and MetadataCatalog objects

The code discussed in this post mainly lives under detectron2/data.

[Figure: layout of the detectron2/data directory]

As the figure shows, datasets contains coco.py, pascal_voc.py, and the rest; samplers generates the indices later consumed by the dataloader; and transforms is the image-augmentation part. The tricky one is catalog.py. It defines the two classes named in this section's title, and at first glance the source is baffling. Here it is:

from collections import UserDict

class _DatasetCatalog(UserDict):

    def register(self, name, func):
        assert callable(func), "You must register a function with `DatasetCatalog.register`!"
        assert name not in self, "Dataset '{}' is already registered!".format(name)
        self[name] = func

    def get(self, name):
        try:
            f = self[name]
        except KeyError as e:
            raise KeyError(
                "Dataset '{}' is not registered! Available datasets are: {}".format(
                    name, ", ".join(list(self.keys()))
                )
            ) from e
        return f()

Pretty abstract. Look closely at the register method: what gets registered is a function, func. The get method then looks that func up by name and finally does return f(), i.e. it executes the function.
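To make that concrete, here is a toy run of the class above (the dataset name and loader function are invented for illustration):

catalog = _DatasetCatalog()

def load_toy_dataset():
    # Would normally parse annotation files; here it returns fake records.
    return [{"file_name": "img_0.jpg", "annotations": []}]

catalog.register("toy_train", load_toy_dataset)   # stores the function, does not call it
dicts = catalog.get("toy_train")                  # calls load_toy_dataset() now
print(dicts[0]["file_name"])                      # img_0.jpg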
OK, so where does register get called? Under the datasets directory there is a builtin.py module (note the `__name__.endswith(".builtin")` check below); here is the relevant part:

if __name__.endswith(".builtin"):
    # Assume pre-defined datasets live in `./datasets`.
    _root = os.getenv("DETECTRON2_DATASETS", "datasets")
    register_all_coco(_root)
    register_all_lvis(_root)
    register_all_cityscapes(_root)
    register_all_cityscapes_panoptic(_root)
    register_all_pascal_voc(_root)
    register_all_ade20k(_root)
    register_all_mot(_root)
    register_all_crowdhuman(_root)
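All of this runs as a side effect of importing the package, so in practice a plain import of detectron2.data should already populate the catalog. A quick way to check (output abbreviated; the exact names depend on your detectron2 version):

from detectron2.data import DatasetCatalog   # importing triggers the register_all_* calls above

print(len(DatasetCatalog))                 # dozens of builtin datasets
print(sorted(DatasetCatalog.keys())[:3])   # e.g. ['cityscapes_fine_instance_seg_train', ...]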

Now let's look at the VOC dataset (from datasets/pascal_voc.py):

def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES):
    DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names))
    MetadataCatalog.get(name).set(
        thing_classes=list(class_names), dirname=dirname, year=year, split=split
    )

OK, now it's clear: DatasetCatalog.register records a name together with a lambda, and that lambda is what loads the dataset's annotation info. In other words, DatasetCatalog registers, for essentially every dataset, the function that fetches its annotations.
These registrations all happen at import time, right after the program starts; the final contents of DatasetCatalog look like this:

[Figure: contents of DatasetCatalog, one registered entry per dataset name]

That is, each dataset name maps to one lambda function.
MetadataCatalog mainly stores per-dataset metadata. I won't expand on it here; honestly I haven't read it closely.
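For comparison, registering your own dataset goes through exactly the same public API; everything in the sketch below (the name, the loader, the classes) is invented:

from detectron2.data import DatasetCatalog, MetadataCatalog

def load_my_instances():
    # Would parse your own annotations into the standard dict format.
    return [{"file_name": "my_img.jpg", "image_id": 0,
             "height": 480, "width": 640, "annotations": []}]

DatasetCatalog.register("my_dataset_train", load_my_instances)
MetadataCatalog.get("my_dataset_train").set(thing_classes=["cat", "dog"])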

2. Building the dataset

1. Reading the dataset via DatasetCatalog

In the earlier post on building the d2 dataloader, a dataset_dict quietly appeared partway through the decorator machinery. This post explains in detail how it is produced.
In data/build.py, the dataset information is loaded by the following code:

def get_detection_dataset_dicts(
    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None):
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]  #(coco_2017_train)

Here dataset_names is 'coco_2017_train'; DatasetCatalog.get retrieves the registered lambda (see Section 1 if you've forgotten) and executes it. After this line, dataset_dicts holds the loaded annotations: a list of dicts carrying keys such as 'file_name' and the bbox annotations.
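For reference, a single record in dataset_dicts follows d2's standard dataset-dict format, roughly like this (the values are illustrative, not real COCO numbers):

from detectron2.structures import BoxMode

record = {
    "file_name": "datasets/coco/train2017/000000000009.jpg",
    "image_id": 9,
    "height": 480,
    "width": 640,
    "annotations": [
        {
            "bbox": [1.0, 187.0, 612.0, 285.0],
            "bbox_mode": BoxMode.XYWH_ABS,   # COCO annotations come as absolute XYWH
            "category_id": 45,
        },
    ],
}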

2. Wrapping dataset_dict with the mapper

With dataset_dicts obtained as above, the next steps in the build.py logic are to construct the mapper, then the sampler.

def _train_loader_from_config(cfg, *, mapper=None, dataset=None, sampler=None):
    if dataset is None:
        dataset = get_detection_dataset_dicts(                  # 读取dataset_dicts
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)                      # 定义了一个mapper

    if sampler is None:                                        # 定义了一个sampler
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = TrainingSampler(len(dataset))
        elif sampler_name == "RepeatFactorTrainingSampler":
            repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                dataset, cfg.DATALOADER.REPEAT_THRESHOLD
            )
            sampler = RepeatFactorTrainingSampler(repeat_factors)
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
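A quick aside on TrainingSampler: it yields an infinite, shuffled stream of indices into the dataset. A small peek (toy size; the actual order depends on the seed):

import itertools
from detectron2.data.samplers import TrainingSampler

sampler = TrainingSampler(size=5, shuffle=True, seed=0)
print(list(itertools.islice(iter(sampler), 8)))   # 8 indices in [0, 5)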

Here is the relevant part of dataset_mapper.py (the file is fairly long). The heart of it is the __call__ method; as the docstring in the source says, it turns a dataset_dict into a format the d2 model can accept. Inside __call__: read the image from file_name --> apply augmentations --> return the image data plus annotations.

class DatasetMapper:
	"""
    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        ...

    def __call__(self, dataset_dict):
        ...
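Leaving out the special cases (masks, keypoints, precomputed proposals), a heavily simplified sketch of what __call__ does might look like the following. This is my own reduction, not the real implementation:

import copy
import numpy as np
import torch
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

def simple_mapper(dataset_dict, augmentations, image_format="BGR"):
    dataset_dict = copy.deepcopy(dataset_dict)            # never mutate the cached dict
    image = utils.read_image(dataset_dict["file_name"], format=image_format)
    aug_input = T.AugInput(image)
    transforms = T.AugmentationList(augmentations)(aug_input)  # applies augs, returns TransformList
    image = aug_input.image
    dataset_dict["image"] = torch.as_tensor(
        np.ascontiguousarray(image.transpose(2, 0, 1))    # HWC -> CHW
    )
    annos = [
        utils.transform_instance_annotations(obj, transforms, image.shape[:2])
        for obj in dataset_dict.pop("annotations", [])
    ]
    dataset_dict["instances"] = utils.annotations_to_instances(annos, image.shape[:2])
    return dataset_dict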

3. Building the dataset

OK, with the mapper and sampler above we can now build the dataset's __getitem__. The entry point is in build.py:

@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):

    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )

Namely the DatasetFromList and MapDataset classes, both defined in data/common.py:

import logging
import random

import torch.utils.data as data

from detectron2.utils.serialize import PicklableWrapper


class MapDataset(data.Dataset):
    """
    Map a function over the elements in a dataset.

    Args:
        dataset: a dataset where map function is applied.
        map_func: a callable which maps the element in dataset. map_func is
            responsible for error handling, when error happens, it needs to
            return None so the MapDataset will randomly use other
            elements from the dataset.
    """

    def __init__(self, dataset, map_func):
        self._dataset = dataset
        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work

        self._rng = random.Random(42)
        self._fallback_candidates = set(range(len(dataset)))

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        retry_count = 0
        cur_idx = int(idx)

        while True:
            data = self._map_func(self._dataset[cur_idx])
            if data is not None:
                self._fallback_candidates.add(cur_idx)
                return data

            # _map_func fails for this idx, use a random new index from the pool
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]

            if retry_count >= 3:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
                        idx, retry_count
                    )
                )
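A toy illustration of that fallback behavior (the file names are fake, and the mapper returns None to mark one sample as "bad"):

from detectron2.data.common import DatasetFromList, MapDataset

dicts = [{"file_name": "img_{}.jpg".format(i)} for i in range(4)]

def mapper(d):
    return None if d["file_name"] == "img_2.jpg" else d   # pretend img_2.jpg is unreadable

dataset = MapDataset(DatasetFromList(dicts, copy=False), mapper)
print(dataset[2])   # falls back to a random other index instead of raising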

Finally, with the dataset and the sampler in hand, the dataloader can be built.
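Putting the whole pipeline together, the usual entry point is just a couple of lines; a sketch, assuming the coco_2017_train data actually sits under ./datasets:

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("coco_2017_train",)
cfg.SOLVER.IMS_PER_BATCH = 2

loader = build_detection_train_loader(cfg)   # runs _train_loader_from_config under the hood
batch = next(iter(loader))                   # a list of dicts, one per image
print(batch[0]["image"].shape, batch[0]["instances"])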

Summary

This post didn't explain the code in much detail; the point was the overall construction flow. Follow-up posts will cover building the model, the optimizer, and so on.


Reposted from blog.csdn.net/wulele2/article/details/119109539