Detectron2 source code reading notes - (3) Dataset pipeline

Main steps for building the data_loader

# engine/defaults.py
from detectron2.data import (
    MetadataCatalog,
    build_detection_test_loader,
    build_detection_train_loader,
)
class DefaultTrainer(SimpleTrainer):
    def __init__(self, cfg):
        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        ...    
    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable
        """
        return build_detection_train_loader(cfg)

The function call relationship is shown in the following figure:

Combining the previous two articles, we can see that detectron2 builds the model, the optimizer, and the data_loader in the corresponding build.py files. Let's look at how build_detection_train_loader is defined (it corresponds to the purple boxes in the figure above, in bottom-up order):


def build_detection_train_loader(cfg, mapper=None):
    """
    A data loader is created by the following steps:

    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
    2. Start workers to work on the dicts. Each worker will:
      * Map each metadata dict into another format to be consumed by the model.
      * Batch them by simply putting dicts into a list.
    The batched ``list[mapped_dict]`` is what this dataloader will return.

    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, True)`.

    Returns:
        a torch DataLoader object
    """
    # Obtain dataset_dicts
    dataset_dicts = get_detection_dataset_dicts(
        cfg.DATASETS.TRAIN,
        filter_empty=True,
        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON
        else 0,
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
    )
    
    # Convert dataset_dicts into a torch.utils.data.Dataset
    dataset = DatasetFromList(dataset_dicts, copy=False)

    # Wrap it further into a MapDataset; every time an item is read, mapper is called to parse the dict
    if mapper is None:
        mapper = DatasetMapper(cfg, True)
    dataset = MapDataset(dataset, mapper)
    
    # Sampler
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    if sampler_name == "TrainingSampler":
        sampler = samplers.TrainingSampler(len(dataset))
        ...
    batch_sampler = build_batch_data_sampler(
        sampler, images_per_worker, group_bin_edges, aspect_ratios
    )
    
    # Data iterator: data_loader
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
        worker_init_fn=worker_init_reset_seed,
    )
    return data_loader

From the source code above we can see there are five steps in total; only the first three are described in detail here. For the later sampler and data_loader steps, see the separate article on the relationship between PyTorch's DataLoader, Dataset, and Sampler.
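As a small aside, this is also the hook for plugging in a custom mapper during training: override DefaultTrainer.build_train_loader and pass your own callable as mapper. A minimal sketch (my own illustration, not from the detectron2 source; here the default DatasetMapper is passed explicitly just to show where a custom one would go):

from detectron2.data import DatasetMapper, build_detection_train_loader
from detectron2.engine import DefaultTrainer

class TrainerWithCustomMapper(DefaultTrainer):
    @classmethod
    def build_train_loader(cls, cfg):
        # any callable that takes a dataset dict and returns the format
        # consumed by the model can be passed here instead
        return build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))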

Obtain dataset_dicts

An important parameter of get_detection_dataset_dicts(dataset_names) is dataset_names, which is simply a string (or a list of strings) specifying the name of the dataset. Using this string, the function calls the DatasetCatalog class in data/catalog.py to resolve it into the dicts containing the dataset information.

The resolution works as follows: DatasetCatalog holds a dictionary _REGISTERED in which, for example, the coco and voc datasets are registered by default. If you want to use your own dataset, you first need to define a dataset name and a function (this function takes no arguments and returns a list of dicts containing your dataset's information). For example:

from detectron2.data import DatasetCatalog
my_dataset_name = 'apple'
def get_dicts():
    ...
    return dataset_dicts  # list[dict] describing your dataset

DatasetCatalog.register(my_dataset_name, get_dicts)
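To make the Detectron2 Dataset format concrete, a fleshed-out (made-up) version of get_dicts for this hypothetical 'apple' dataset could look like the following; the file path, sizes, and box are placeholders, and each image corresponds to one dict:

from detectron2.structures import BoxMode

def get_dicts():
    # one dict per image, in Detectron2 Dataset format
    return [
        {
            "file_name": "path/to/apple_0.jpg",   # placeholder path
            "image_id": 0,
            "height": 480,
            "width": 640,
            "annotations": [
                {
                    "bbox": [100, 120, 50, 60],    # x, y, w, h
                    "bbox_mode": BoxMode.XYWH_ABS,
                    "category_id": 0,
                },
            ],
        },
    ]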

Of course, if your dataset is already in COCO format, then you can register it with the following method:

from detectron2.data.datasets import register_coco_instances
my_dataset_name = 'apple'
register_coco_instances(my_dataset_name, {}, "json_annotation.json", "path/to/image/dir")

Also note that a dataset is actually described by two catalogs: one is the DatasetCatalog introduced above, and the other is MetadataCatalog.

MetadataCatalog records some characteristics of a dataset, so that we can conveniently access this information anywhere in the code. After registering with DatasetCatalog, we can register with MetadataCatalog and define attributes that we may need later, for example:

from detectron2.data import MetadataCatalog
MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"]

# This also works
MetadataCatalog.get("my_dataset").set("thing_classes",["person", "dog"])

Note: if your dataset name has not been registered yet, MetadataCatalog.get will register it automatically and then set the attribute values you specify.
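For example, reading the metadata back later in the code:

from detectron2.data import MetadataCatalog

metadata = MetadataCatalog.get("my_dataset")
print(metadata.thing_classes)  # ['person', 'dog']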

MetadataCatalog actually has other attributes that can be set as well, such as stuff_classes, stuff_colors, and so on. You may wonder what the difference between thing_classes and stuff_classes is; it is as follows:

  • Abstract explanation: thing_classes is used for instance-level tasks, while stuff_classes is used for semantic segmentation tasks.
  • Concrete explanation: countable objects such as chairs and books can be understood as "things", so they belong to instance-level tasks; uncountable ones such as snow and sky are "stuff", so they belong to semantic segmentation. See On Seeing Stuff: The Perception of Materials by Humans and Machines.

Finally, get_detection_dataset_dicts returns a list containing a number of dicts. It is a list because the dataset_names parameter is also a list, so we can specify multiple names and read several datasets at once.
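As a sketch of how this is used directly (assuming both names have been registered in DatasetCatalog beforehand):

from detectron2.data.build import get_detection_dataset_dicts

dataset_dicts = get_detection_dataset_dicts(
    ["apple", "my_dataset"],  # multiple names: their dicts are concatenated
    filter_empty=True,
)
print(len(dataset_dicts))     # total number of images across both datasets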

Convert into DatasetFromList

DatasetFromList(dataset_dicts) is defined in detectron2/data/common.py. It is essentially just a torch.utils.data.Dataset; its source code is as follows:

import copy
import torch.utils.data as data

class DatasetFromList(data.Dataset):
    """
    Wrap a list to a torch Dataset. It produces elements of the list as data.
    """

    def __init__(self, lst: list, copy: bool = True):
        """
        Args:
            lst (list): a list which contains elements to produce.
            copy (bool): whether to deepcopy the element when producing it,
                so that the result can be modified in place without affecting the
                source in the list.
        """
        self._lst = lst
        self._copy = copy

    def __len__(self):
        return len(self._lst)

    def __getitem__(self, idx):
        if self._copy:
            return copy.deepcopy(self._lst[idx])
        else:
            return self._lst[idx]

This one is simple enough that no further explanation is needed.
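A tiny usage example, independent of detectron2's dataset dicts, just to show the wrapping:

from detectron2.data.common import DatasetFromList

dataset = DatasetFromList([{"id": 0}, {"id": 1}], copy=False)
print(len(dataset))  # 2
print(dataset[1])    # {'id': 1}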

Convert DatasetFromList into MapDataset

In fact both DatasetFromList and MapDataset are subclasses of torch.utils.data.Dataset, so what is the difference between them? Simple: the latter uses a mapper.

Before explaining what the mapper is, we first need to know that in detectron2 one image corresponds to one dict, so the whole dataset corresponds to a list[dict]. Looking back at DatasetFromList, its __getitem__ is very simple: it just returns the element at the given idx. Clearly that alone is not enough, because before feeding the data to the model for training we still need to process it, and that work is done by the mapper. By default the DatasetMapper defined in detectron2/data/dataset_mapper.py is used; if you need a custom mapper you can use it as a reference (a bare-bones sketch follows below).
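For illustration only (my own example, not from the source), a custom mapper could be as minimal as this; it only reads the image and converts it to a CHW float tensor, and skips the augmentation and annotation/Instances handling that a real training mapper would also need:

import copy
import torch
from detectron2.data import detection_utils as utils

def my_mapper(dataset_dict):
    dataset_dict = copy.deepcopy(dataset_dict)  # do not modify the registered dict
    image = utils.read_image(dataset_dict["file_name"], format="BGR")
    dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
    return dataset_dict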

DatasetMapper(cfg, is_train=True)

Let's now look at how DatasetMapper is implemented. First, the official definition:

A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by the model.

In short, this class is callable, which is why the __call__ method is defined in the source code below.

The class mainly does the following three things:

The callable currently does the following:

  1. Read the image from "file_name"
  2. Applies cropping/geometric transforms to the image and annotations
  3. Prepare data and annotations to Tensor and :class:Instances

Its (abridged) source code is as follows:

class DatasetMapper:
    def __init__(self, cfg, is_train=True):
        # Read parameters from cfg
        ...

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        
        # 1. Read the image data
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        
        # 2. Apply transformations to the image, boxes, etc.
        if "annotations" not in dataset_dict:
            image, transforms = T.apply_transform_gens(
                ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
            )
        else:
            ...
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            if self.crop_gen:
                transforms = crop_tfm + transforms

        image_shape = image.shape[:2]  # h, w
        
        # 3. Convert the data into tensor format
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
        ...

        return dataset_dict
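
A rough usage sketch (assuming "my_dataset" from the earlier registration example and the default config; in practice cfg comes from your training setup):

from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, DatasetMapper

cfg = get_cfg()                            # default config, only for illustration
mapper = DatasetMapper(cfg, is_train=True)
d = DatasetCatalog.get("my_dataset")[0]    # one dict in Detectron2 Dataset format
out = mapper(d)                            # has an "image" tensor (and "instances" if annotations exist)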

MapDataset

class MapDataset(data.Dataset):
    def __init__(self, dataset, map_func):
        self._dataset = dataset
        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work

        self._rng = random.Random(42)
        self._fallback_candidates = set(range(len(dataset)))

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        retry_count = 0
        cur_idx = int(idx)

        while True:
            data = self._map_func(self._dataset[cur_idx])
            if data is not None:
                self._fallback_candidates.add(cur_idx)
                return data

            # _map_func fails for this idx, use a random new index from the pool
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]

            if retry_count >= 3:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
                        idx, retry_count
                    )
                )

  • self._fallback_candidates is a set, so its elements are unique. It records the indices of samples that can be read normally: some samples may fail to be read, and in that case the bad index is removed from _fallback_candidates and a new index is randomly sampled from it instead.
  • The logic in __getitem__ is: first read the data at the given index; if it reads correctly, add that index to _fallback_candidates and return the data. Otherwise remove the index, randomly sample another one, and try again. After 3 failed attempts a warning is logged, but the program does not exit; it just keeps retrying with new indices, presumably on the assumption that a readable sample will eventually be found.
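A toy sketch of this fallback behavior (my own example): a mapper that returns None for "broken" samples is silently skipped, and another index from _fallback_candidates is used instead.

from detectron2.data.common import DatasetFromList, MapDataset

def mapper(d):
    # pretend every odd element is broken and cannot be parsed
    return None if d["id"] % 2 == 1 else d

dataset = MapDataset(DatasetFromList([{"id": i} for i in range(4)]), mapper)
print(dataset[1])  # idx 1 fails, so a randomly chosen good element is returned instead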


Original post by MARSGGBO

If you are interested in collaborating, feel free to get in touch.

Email: [email protected]


2019-10-23 13:37:13


