In-depth explanation of Dataset and DataLoader

Dataset & DataLoader

1. Official Explanation (translated from the PyTorch documentation):
Code for processing data samples can get messy and hard to maintain; ideally, we want our dataset code to be decoupled from our model training code for better readability and modularity.
PyTorch provides two data primitives, torch.utils.data.DataLoader and torch.utils.data.Dataset, which allow us to use preloaded datasets as well as our own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset for easy access to the samples.
2. Dataset
Dataset is the base class that all datasets used for training and testing inherit from.
A Dataset defines the contents of the dataset. It is equivalent to a list-like data structure with a fixed length, and its elements can be retrieved by index.
3. DataLoader
A DataLoader defines how to load a dataset in batches. It is an iterable that implements the __iter__ method; each iteration yields one batch of data.
The DataLoader controls the batch size, the way elements are sampled into each batch, and the way a batch is collated into the input form required by the model, and it can use multiple processes to read data.
In most cases, we only need to implement the __len__ and __getitem__ methods of Dataset to easily build our own dataset and load it with the default data pipeline.
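A minimal sketch of this contract (the SquaresDataset class and its contents are illustrative assumptions, not part of PyTorch):

import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    # hypothetical toy dataset: feature i maps to label i * i
    def __len__(self):
        # the dataset behaves like a list with a fixed length
        return 10

    def __getitem__(self, index):
        # index-based access: return one (feature, label) pair as tensors
        x = torch.tensor([float(index)])
        y = torch.tensor(float(index * index))
        return x, y


loader = DataLoader(SquaresDataset(), batch_size=4, shuffle=True)
for features, labels in loader:
    # each iteration yields one batch of 4 samples (the last batch may be smaller)
    print(features.shape, labels.shape)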

1. Custom Dataset

A custom Dataset class must inherit from PyTorch's Dataset class and implement three functions: __init__, __len__, and __getitem__.
__init__: initialization (typically receives the path to the annotations file, the directory where the data files are stored, and optional preprocessing functions)
__len__: returns the size of the dataset
__getitem__: returns the features and label of the sample at the given index

import os.path

import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image


class MyImageDataset(Dataset):
    def __init__(self, annotations_file, data_dir, transform=None, target_transform=None):
        # annotations_file: path to the CSV annotations file (file name, label per row)
        # data_dir: directory where the data files are stored
        self.data_label = pd.read_csv(annotations_file)
        self.data_dir = data_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # return the total size of the dataset
        return len(self.data_label)

    def __getitem__(self, item):
        data_name = os.path.join(self.data_dir, self.data_label.iloc[item, 0])
        image = read_image(data_name)
        label = self.data_label.iloc[item, 1]
        # preprocess the features
        if self.transform:
            image = self.transform(image)
        # preprocess the label
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In practice, we usually only need to adjust the four arguments annotations_file, data_dir, transform (feature preprocessing), and target_transform (label preprocessing).
A Dataset processes one sample at a time, returning a single feature together with its corresponding label.
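For example, a hypothetical instantiation (the file names and the Resize transform below are illustrative assumptions; annotations.csv is assumed to list one image file name and one integer label per row):

from torchvision import transforms

training_data = MyImageDataset(
    annotations_file="annotations.csv",  # hypothetical CSV of "file_name,label" rows
    data_dir="images/",                  # hypothetical image directory
    transform=transforms.Resize((224, 224)),  # preprocess the image features
    target_transform=None,                    # labels are used as-is
)

image, label = training_data[0]  # indexing calls __getitem__
print(image.shape, label)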

2. Use DataLoaders to prepare data for training

The Dataset retrieves our dataset's features and labels one sample at a time. When training a model, we typically want to pass samples in "mini-batches", reshuffle the data at every epoch (one full pass over the dataset) to reduce model overfitting, and use Python multiprocessing to speed up data retrieval.

batch_size: the number of samples drawn for each training step
shuffle=True: reshuffle the data at the start of every epoch

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
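
The multi-process reading mentioned earlier is enabled through the num_workers argument. A sketch (num_workers=4 is an illustrative value; on platforms that spawn worker processes, such as Windows, this should run under an if __name__ == "__main__": guard):

train_dataloader = DataLoader(
    training_data,
    batch_size=64,
    shuffle=True,
    num_workers=4,    # read data in 4 worker processes (0 = load in the main process)
    pin_memory=True,  # optional: speeds up host-to-GPU copies when training on CUDA
)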

3. Iterating over the data

Having loaded the dataset into the DataLoader, we can iterate over it as needed. Each iteration below returns one batch of train_features and train_labels (containing batch_size=64 features and labels, respectively).
The built-in iter() function obtains an iterator from the DataLoader.
The built-in next() function advances the iterator, returning the next batch of features and labels.

train_features, train_labels = next(iter(train_dataloader))
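
In a full training loop we would normally iterate over the entire DataLoader rather than call next() once; a sketch of that pattern (the shape comment assumes the transform yields same-size images):

for batch, (train_features, train_labels) in enumerate(train_dataloader):
    # train_features: [64, C, H, W] tensor; train_labels: [64] tensor
    # (the final batch may contain fewer than 64 samples)
    print(f"batch {batch}: {train_features.shape}, {train_labels.shape}")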
