Paddle: Load custom data sets

Paddle provides two ways to load datasets:

1: Load built-in data

2: Load custom data

1: Load built-in data

The Paddle framework ships several classic datasets in the paddle.vision.datasets and paddle.text modules, and you can use them directly. To list the built-in datasets, run the following code:

import paddle
print('Computer vision (CV) datasets:', paddle.vision.datasets.__all__)
print('Natural language processing (NLP) datasets:', paddle.text.__all__)

For details on using them, see the official PaddlePaddle documentation page "Dataset Definition and Loading".

2: Load custom data 

In practice, you usually need to build a dataset from your own data. In that case, you implement a custom dataset by subclassing the paddle.io.Dataset base class.

Create a subclass of paddle.io.Dataset and implement the following three methods:

(Look familiar? It is practically identical to PyTorch's torch.utils.data.Dataset. The resemblance is intentional, to make migrating code from PyTorch easier.)

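1. __init__: initializes the dataset by mapping the sample file paths on disk and their labels into a list.

2. __getitem__: defines how to fetch the sample at a given index, returning a single record for that index (the sample data and its label).

3. __len__: returns the total number of samples in the dataset.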
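As a minimal sketch of how the three methods fit together, the class below reads a list file in which each line is "<image_path> <label> [<label> ...]". This is the same format the sample code below parses. ListFileDataset and the file contents are hypothetical, and __getitem__ returns the image path instead of the decoded image so the sketch stays small. The base class falls back to object when Paddle is not installed, but real code should inherit from paddle.io.Dataset.

```python
import os
import tempfile

try:
    from paddle.io import Dataset  # the real base class when Paddle is installed
except ImportError:
    Dataset = object  # fallback only so this sketch runs without Paddle


class ListFileDataset(Dataset):
    """Hypothetical example: each line is '<image_path> <label> [<label> ...]'."""

    def __init__(self, list_file):
        super().__init__()
        # __init__: map the sample paths and labels on disk into a list,
        # skipping blank lines
        with open(list_file, "r", encoding="utf-8") as f:
            self.samples = [line.strip() for line in f if line.strip()]

    def __getitem__(self, index):
        # __getitem__: turn one line into a single (sample, labels) pair;
        # the path stands in for the decoded image in this sketch
        path, *labels = self.samples[index].split()
        return path, [int(x) for x in labels]

    def __len__(self):
        # __len__: total number of samples
        return len(self.samples)


# Usage: write a small hypothetical list file and index into the dataset
with tempfile.TemporaryDirectory() as tmp:
    list_path = os.path.join(tmp, "train_list.txt")
    with open(list_path, "w", encoding="utf-8") as f:
        f.write("images/0001.jpg 3 1\nimages/0002.jpg 0 2\n\n")
    ds = ListFileDataset(list_path)

print(len(ds), ds[0])
```

Because __init__ loads the whole list into memory, the dataset stays usable after the temporary file is deleted. Images are only read lazily, one index at a time.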

Here is the complete sample code:

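import cv2
import numpy as np
from paddle.io import Dataset
from paddle.vision import transforms as T


'''
Paddle API documentation: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html
'''

class ListDataset(Dataset):
    def __init__(self, list_file, mode='train'):
        super().__init__()
        if mode == 'train':
            print("Loading train data ......")
        else:
            print("Loading test data ......")
        # 'train' enables random augmentation; any other mode uses the eval pipeline
        self.mode = mode
        # Load the list file; each non-empty line is "<image_path> <label> [<label> ...]"
        with open(list_file, "r", encoding="utf-8") as f:
            self.data_list = [line for line in f if line.strip()]
        # Training transforms: resize to (h, w) = (128, 64), random contrast/brightness/flip/rotation,
        # HWC -> CHW, then (x - 127.5) / 127.5 maps pixels to [-1, 1]; to_rgb=True requests BGR -> RGB
        self.transform_train = T.Compose([
            T.Resize((128, 64), interpolation='nearest'),
            T.ContrastTransform(0.2),
            T.BrightnessTransform(0.2),
            T.RandomHorizontalFlip(0.5),
            T.RandomRotation(15),
            T.Transpose(),
            T.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], data_format='CHW', to_rgb=True)])
        # Evaluation transforms: same resizing and normalization, no random augmentation
        self.transform_eval = T.Compose([
            T.Resize((128, 64), interpolation='nearest'),
            T.Transpose(),
            T.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], data_format='CHW', to_rgb=True)])

    def __getitem__(self, index):
        # Split "<image_path> <label> ..." on any whitespace
        line_info = self.data_list[index].strip().split()
        # cv2.imread returns an HWC uint8 array in BGR order, or None if the file cannot be read
        img_bgr = cv2.imread(line_info[0])
        if img_bgr is None:
            raise FileNotFoundError("Cannot read image: " + line_info[0])
        # Labels as an int64 array so the default collate function stacks them into one batch array
        img_label = np.array([int(i) for i in line_info[1:]], dtype='int64')
        if self.mode == 'train':
            img = self.transform_train(img_bgr)
        else:
            img = self.transform_eval(img_bgr)

        return img, img_label

    def __len__(self):
        # Total number of samples
        return len(self.data_list)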
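To make the preprocessing concrete, here is a NumPy sketch of what the final steps of the pipeline compute on a BGR image from cv2.imread. The channels are reversed to RGB (what to_rgb=True asks for), HWC becomes CHW (T.Transpose() with its default order), and (x - 127.5) / 127.5 maps pixel values from [0, 255] to [-1, 1]. This is an illustrative re-implementation, not Paddle's source code. The name to_model_input is hypothetical, and resizing and augmentation are left out.

```python
import numpy as np


def to_model_input(img_bgr, mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)):
    """Hypothetical helper: BGR HWC uint8 image -> normalized RGB CHW float32 array."""
    img_rgb = img_bgr[..., ::-1]          # reverse the channel axis: BGR -> RGB
    img_chw = img_rgb.transpose(2, 0, 1)  # HWC -> CHW
    # One mean/std value per channel, shaped to broadcast over (C, H, W)
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (img_chw.astype(np.float32) - mean) / std  # [0, 255] -> [-1, 1]


# A dummy image with (h, w) = (128, 64), the size T.Resize((128, 64)) produces.
# OpenCV stores channels as B, G, R: blue and green are 0, red is 255.
img_bgr = np.zeros((128, 64, 3), dtype=np.uint8)
img_bgr[..., 2] = 255
out = to_model_input(img_bgr)
print(out.shape, out[0].max(), out[2].min())
```

After the channel flip, red becomes channel 0, so that plane normalizes to 1.0 and the two zero-valued planes normalize to -1.0.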

If an API is unclear, read the official Paddle documentation first. If it is still unclear, read the corresponding PyTorch documentation, since the two APIs are largely the same.

paddle-API documentation: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html


Origin blog.csdn.net/lilai619/article/details/128579889