Image Segmentation Suite PaddleSeg Comprehensive Analysis (Part 3): Dataset Code Interpretation

In the yaml configuration file, the type of train_dataset is configured as Cityscapes. From the interpretation of the Config code in the previous part, we know that the Cityscapes object is created lazily the first time the train_dataset property of the Config object is accessed.
The Cityscapes class is located in paddleseg/datasets/cityscapes.py. Its parent class is Dataset, defined in dataset.py in the same directory, so I will start with the Dataset class.

First, let's look at the constructor of Dataset. It is relatively long and contains validation logic along with member-variable initialization:

def __init__(self,
             transforms,        # transforms applied to each image
             dataset_root,      # root path of the dataset
             num_classes,       # number of classes
             mode='train',      # running mode: 'train', 'val' or 'test'
             train_path=None,   # path to the train list file; on each line the first entry is the sample file and the second is the annotation file, e.g. image1.jpg ground_truth1.png
             val_path=None,     # path to the val list file, same format as the train list
             test_path=None,    # same format as the train list, but the annotation file is optional
             separator=' ',     # separator between sample and annotation paths in the list file, default is a space
             ignore_index=255,  # class id to ignore
             edge=False):       # whether to compute edges during training
    # save the dataset root path
    self.dataset_root = dataset_root
    # build the data augmentation object
    self.transforms = Compose(transforms)
    # create an empty list to hold file paths
    self.file_list = list()
    # lower-case the mode string and store it as a member variable
    mode = mode.lower()
    self.mode = mode
    # save the number of classes
    self.num_classes = num_classes
    # save the class id to ignore, usually 255
    self.ignore_index = ignore_index
    # save the edge flag
    self.edge = edge

    # raise if mode is not one of train/val/test
    if mode.lower() not in ['train', 'val', 'test']:
        raise ValueError(
            "mode should be 'train', 'val' or 'test', but got {}.".format(
                mode))
    # the data augmentation object is mandatory; raise if it is unset
    if self.transforms is None:
        raise ValueError("`transforms` is necessary, but it is None.")
    # raise if the dataset root path does not exist
    self.dataset_root = dataset_root
    if not os.path.exists(self.dataset_root):
        raise FileNotFoundError('there is not `dataset_root`: {}.'.format(
            self.dataset_root))
    # check that the list file for the current mode exists; raise if not,
    # otherwise save it to the file_path variable
    if mode == 'train':
        if train_path is None:
            raise ValueError(
                'When `mode` is "train", `train_path` is necessary, but it is None.'
            )
        elif not os.path.exists(train_path):
            raise FileNotFoundError(
                '`train_path` is not found: {}'.format(train_path))
        else:
            file_path = train_path
    elif mode == 'val':
        if val_path is None:
            raise ValueError(
                'When `mode` is "val", `val_path` is necessary, but it is None.'
            )
        elif not os.path.exists(val_path):
            raise FileNotFoundError(
                '`val_path` is not found: {}'.format(val_path))
        else:
            file_path = val_path
    else:
        if test_path is None:
            raise ValueError(
                'When `mode` is "test", `test_path` is necessary, but it is None.'
            )
        elif not os.path.exists(test_path):
            raise FileNotFoundError(
                '`test_path` is not found: {}'.format(test_path))
        else:
            file_path = test_path
    # open the list file; it contains one line per dataset sample. Train and
    # val lists contain both the sample path and the label path; the test
    # list contains only the sample path.
    with open(file_path, 'r') as f:
        # iterate over each line of the list file
        for line in f:
            # split the line into sample path and label path
            items = line.strip().split(separator)
            # raise if a train/val line does not contain both paths;
            # in test mode a line may contain only the sample path
            if len(items) != 2:
                if mode == 'train' or mode == 'val':
                    raise ValueError(
                        "File list format incorrect! In training or evaluation task it should be"
                        " image_name{}label_name\\n".format(separator))
                image_path = os.path.join(self.dataset_root, items[0])
                label_path = None
            else:
                # join the dataset root with the relative sample and label paths
                image_path = os.path.join(self.dataset_root, items[0])
                label_path = os.path.join(self.dataset_root, items[1])
            # store the sample path and label path in the file list
            self.file_list.append([image_path, label_path])
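
The list-file parsing above can be sketched as a standalone function. Note that `parse_list_line` and the example paths below are hypothetical, for illustration only; the real logic lives inline in the constructor.

```python
import os

# A standalone sketch of the line-parsing logic in the constructor above.
# parse_list_line and the example paths are made up for illustration.
def parse_list_line(line, dataset_root, separator=' ', mode='train'):
    items = line.strip().split(separator)
    if len(items) != 2:
        # train/val lines must contain both paths; a test line may omit the label
        if mode in ('train', 'val'):
            raise ValueError('File list format incorrect!')
        return [os.path.join(dataset_root, items[0]), None]
    return [os.path.join(dataset_root, items[0]),
            os.path.join(dataset_root, items[1])]

print(parse_list_line('image1.jpg ground_truth1.png', '/data/my_set'))
print(parse_list_line('image1.jpg', '/data/my_set', mode='test'))
```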

When the __getitem__ method is defined in a class, an instance of that class (say p) can be indexed as p[key]; the p[key] operation calls the class's __getitem__ method.
This also makes the instance iterable, so its elements can be accessed by subscript.
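
As a minimal illustration of this protocol (the `Squares` class below is made up, not PaddleSeg code):

```python
# Defining __getitem__ lets p[key] work, and Python's for-loop falls back
# to calling __getitem__ with 0, 1, 2, ... until IndexError, so the
# object is also iterable. Squares is a made-up example class.
class Squares:
    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx * idx

p = Squares(4)
print(p[3])       # indexing calls Squares.__getitem__
print(list(p))    # iteration also works, via the same method
```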

The following explains how the Dataset class returns samples and labels through file_list.

def __getitem__(self, idx):
    # use the idx subscript to look up the sample image path and label
    # image path in file_list
    image_path, label_path = self.file_list[idx]
    # in test mode, return the image as ndarray data. The transforms object
    # handles reading and preprocessing the image; datasets in different
    # modes use different transforms objects.
    if self.mode == 'test':
        im, _ = self.transforms(im=image_path)
        im = im[np.newaxis, ...]
        return im, image_path
    # in train or val mode, the label image's ndarray data must be returned
    # along with the sample image's
    elif self.mode == 'val':
        im, _ = self.transforms(im=image_path)
        label = np.asarray(Image.open(label_path))
        label = label[np.newaxis, :, :]
        return im, label
    else:
        im, label = self.transforms(im=image_path, label=label_path)
        if self.edge:
            edge_mask = F.mask_to_binary_edge(
                label, radius=2, num_classes=self.num_classes)
            return im, label, edge_mask
        else:
            return im, label

When __len__ is defined in a class, the len function can be used to get an instance's length. The Dataset class stores the file list, and we need len to report the number of samples in the dataset, so the Dataset class implements the __len__ method.

def __len__(self):
    # simply return the length of the file_list list
    return len(self.file_list)
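
A made-up miniature (`TinyDataset` is not part of PaddleSeg) showing why implementing __len__ this way is enough for len() to report the sample count:

```python
# With __len__ defined, len(instance) works; for the Dataset class this
# is simply the number of [image, label] entries in file_list.
class TinyDataset:
    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

ds = TinyDataset([['img1.jpg', 'lbl1.png'], ['img2.jpg', 'lbl2.png']])
print(len(ds))
```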

That covers the implementation of the Dataset class; now let's look at an actual dataset, Cityscapes.
The Cityscapes class is defined in the dygraph/paddleseg/datasets/cityscapes.py file. As a subclass of Dataset, it naturally inherits the __getitem__ and __len__ methods, and the code in those
two methods is reusable: __getitem__ handles the preprocessing of sample and label images, which is similar for any dataset, and dataset-specific preprocessing can be handled by the
transforms object passed in. So in the Cityscapes class, we only need to care about the constructor.

def __init__(self, transforms, dataset_root, mode='train', edge=False):
    # this part is basically the same as the Dataset class: save some member
    # variables, except that this dataset is fixed at 19 classes and
    # ignore_index is fixed at 255
    self.dataset_root = dataset_root
    self.transforms = Compose(transforms)
    self.file_list = list()
    mode = mode.lower()
    self.mode = mode
    self.num_classes = 19
    self.ignore_index = 255
    self.edge = edge

    if mode not in ['train', 'val', 'test']:
        raise ValueError(
            "mode should be 'train', 'val' or 'test', but got {}.".format(
                mode))

    if self.transforms is None:
        raise ValueError("`transforms` is necessary, but it is None.")
    # different datasets organize their files differently; in the Cityscapes
    # dataset the sample images and label images live under the leftImg8bit
    # and gtFine directories respectively
    img_dir = os.path.join(self.dataset_root, 'leftImg8bit')
    label_dir = os.path.join(self.dataset_root, 'gtFine')
    if self.dataset_root is None or not os.path.isdir(
            self.dataset_root) or not os.path.isdir(
                img_dir) or not os.path.isdir(label_dir):
        raise ValueError(
            "The dataset is not Found or the folder structure is nonconfoumance."
        )
    # instead of reading a list file, glob is used here with shell-style
    # wildcard patterns to match the label image paths
    label_files = sorted(
        glob.glob(
            os.path.join(label_dir, mode, '*',
                         '*_gtFine_labelTrainIds.png')))
    # same as above, collect the sample image paths
    img_files = sorted(
        glob.glob(os.path.join(img_dir, mode, '*', '*_leftImg8bit.png')))
    # build the file list; each element is a two-element list of the form
    # [sample image path, label image path], consumed by the parent class's
    # __getitem__ to preprocess the image data
    self.file_list = [[
        img_path, label_path
    ] for img_path, label_path in zip(img_files, label_files)]
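
Note that glob patterns are shell-style wildcards, not regular expressions. A small sketch (the file names below are hypothetical Cityscapes-style names) of why the '*_gtFine_labelTrainIds.png' pattern picks out only the training-id label maps:

```python
import fnmatch

# glob matches shell-style wildcards; fnmatch applies the same matching
# rules to bare names. The names below are hypothetical examples in the
# Cityscapes naming scheme.
pattern = '*_gtFine_labelTrainIds.png'
names = [
    'aachen_000000_000019_gtFine_labelTrainIds.png',
    'aachen_000000_000019_gtFine_color.png',
    'aachen_000000_000019_leftImg8bit.png',
]
matches = [n for n in names if fnmatch.fnmatch(n, pattern)]
print(matches)
```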

The main code of the dataset part has now been covered. The key takeaway is that the PaddleSeg suite provides the Dataset class as a base class, so to add a new dataset we can inherit from
Dataset and implement our own __init__ constructor, using the Cityscapes class as a reference.
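
As a runnable sketch, the core of what a custom subclass's __init__ would do is build file_list. The 'images'/'labels' layout and the file suffixes below are hypothetical; a real subclass would also set transforms, num_classes, mode, ignore_index and edge as Cityscapes does, then rely on the inherited __getitem__ and __len__.

```python
import glob
import os
import tempfile

# Hypothetical file_list construction for a custom dataset whose samples
# live under images/<mode> and labels under labels/<mode>.
def build_file_list(dataset_root, mode='train'):
    img_files = sorted(
        glob.glob(os.path.join(dataset_root, 'images', mode, '*.jpg')))
    label_files = sorted(
        glob.glob(os.path.join(dataset_root, 'labels', mode, '*.png')))
    return [[img, lbl] for img, lbl in zip(img_files, label_files)]

# Exercise the function on a throwaway directory tree.
root = tempfile.mkdtemp()
for sub in (os.path.join('images', 'train'), os.path.join('labels', 'train')):
    os.makedirs(os.path.join(root, sub))
for name in ('a', 'b'):
    open(os.path.join(root, 'images', 'train', name + '.jpg'), 'w').close()
    open(os.path.join(root, 'labels', 'train', name + '.png'), 'w').close()

for img, lbl in build_file_list(root):
    print(os.path.basename(img), '->', os.path.basename(lbl))
```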

PaddleSeg repository address: https://github.com/PaddlePaddle/PaddleSeg

Origin blog.csdn.net/txyugood/article/details/111031965