Dataset & DataLoader
1. Official explanation
Code that processes data samples can become messy and difficult to maintain; ideally, we want our dataset code to be decoupled from our model training code for better readability and modularity.
PyTorch provides two data primitives, torch.utils.data.DataLoader and torch.utils.data.Dataset, which allow us to use preloaded datasets as well as our own data. The Dataset stores the samples and their corresponding labels, and the DataLoader wraps an iterable around the Dataset for easy access to the samples.
2. Dataset
Dataset is the base class (template) for every dataset used in training and testing.
A Dataset defines the content of the data set. It behaves like a list: it has a fixed length, and its elements can be retrieved by index.
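The list-like behaviour described above can be sketched with a toy dataset. SquaresDataset is a hypothetical name used only for illustration; the only assumption is that PyTorch is installed.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples in the dataset
        return self.n

    def __getitem__(self, idx):
        # Reject out-of-range indices, as a list would
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        x = torch.tensor(float(idx))
        return x, x ** 2


dataset = SquaresDataset(5)
feature, label = dataset[3]   # index into the dataset like a list
print(len(dataset), feature.item(), label.item())
```

Because the class supports len() and integer indexing, it can be passed to a DataLoader without any further changes.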
3. DataLoader
DataLoader defines how a dataset is loaded in batches. It is an iterable object that implements the __iter__ method, and each iteration yields one batch of data.
DataLoader controls the batch size, how the elements of each batch are sampled, and how the sampled elements are collated into the input form the model requires. It can also read data with multiple processes.
In most cases, we only need to implement the __len__ and __getitem__ methods of Dataset; we can then build our own dataset and load it with the default data pipeline.
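As a rough sketch of these controls, the example below sets the batch size, the sampling order, a custom collate function, and the number of worker processes. VariableLengthDataset and pad_collate are hypothetical names introduced for this illustration; the padding strategy is just one possible way to collate samples of different lengths.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class VariableLengthDataset(Dataset):
    """Hypothetical dataset whose samples have different lengths."""

    def __init__(self):
        self.sequences = [torch.ones(n) for n in (2, 5, 3, 4)]
        self.labels = [0, 1, 0, 1]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]


def pad_collate(batch):
    # batch is a list of (sequence, label) pairs produced by __getitem__
    sequences, labels = zip(*batch)
    # Zero-pad every sequence to the length of the longest one in the batch
    padded = pad_sequence(list(sequences), batch_first=True)
    return padded, torch.tensor(labels)


loader = DataLoader(
    VariableLengthDataset(),
    batch_size=2,            # samples per batch
    shuffle=False,           # sequential sampling, keeps dataset order
    collate_fn=pad_collate,  # how samples are merged into one batch
    num_workers=0,           # 0 loads in the main process; >0 uses worker processes
)

batches = list(loader)
for padded, labels in batches:
    print(padded.shape, labels)
```

When all samples already share one shape, collate_fn can be omitted and the default collation stacks them into a single tensor.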
1. Custom Dataset
A custom dataset class must inherit from PyTorch's Dataset class and implement three methods: __init__, __len__ and __getitem__.
__init__: initialization (usually receives the path of the annotations file, the directory where the data files are stored, and the preprocessing functions)
__len__: returns the size of the data set
__getitem__: returns the features and label of the sample at the given index
import os

import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image


class MyImageDataset(Dataset):
    def __init__(self, annotations_file, data_dir, transform=None, target_transform=None):
        # annotations_file: path to the CSV file (column 0: file name, column 1: label)
        # data_dir: directory where the image files are stored
        self.data_label = pd.read_csv(annotations_file)
        self.data_dir = data_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Return the total size of the data set
        return len(self.data_label)

    def __getitem__(self, item):
        data_name = os.path.join(self.data_dir, self.data_label.iloc[item, 0])
        image = read_image(data_name)
        label = self.data_label.iloc[item, 1]
        # Preprocess the features
        if self.transform:
            image = self.transform(image)
        # Preprocess the label
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
In practice, to reuse this class we only need to supply its four parameters: annotations_file, data_dir, transform (feature preprocessing), and target_transform (label preprocessing).
A Dataset handles only one sample at a time: each call to __getitem__ returns one feature and the label corresponding to it.
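To illustrate how transform and target_transform are applied per sample, here is a minimal sketch that keeps random images in memory instead of reading files. InMemoryImageDataset, to_unit_range, and one_hot are hypothetical names, and the choice of scaling and one-hot encoding is only an example of possible preprocessing.

```python
import torch
from torch.utils.data import Dataset


class InMemoryImageDataset(Dataset):
    """Hypothetical stand-in for MyImageDataset that keeps images in memory."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


def to_unit_range(image):
    # Feature preprocessing: uint8 pixels in [0, 255] -> floats in [0, 1]
    return image.float() / 255.0


def one_hot(label, num_classes=3):
    # Label preprocessing: integer class index -> one-hot float vector
    return torch.nn.functional.one_hot(torch.tensor(label), num_classes).float()


images = torch.randint(0, 256, (4, 1, 8, 8), dtype=torch.uint8)  # made-up data
labels = [0, 2, 1, 2]
dataset = InMemoryImageDataset(images, labels, to_unit_range, one_hot)

image, label = dataset[1]
print(image.dtype, image.shape, label)
```

Both functions run inside __getitem__, so they see a single sample, never a whole batch.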
2. Use DataLoaders to prepare data for training
The Dataset retrieves our dataset's features and labels one sample at a time. When training a model, we typically want to pass samples in "mini-batches", reshuffle the data at every epoch (one full pass over the training set) to reduce model overfitting, and use Python's multiprocessing to speed up data retrieval.
batch_size: the number of samples in each batch
shuffle=True: reshuffle the data at the start of every epoch
from torch.utils.data import DataLoader

# training_data and test_data are Dataset instances created beforehand
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
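The effect of shuffle=True can be checked with a small sketch. It assumes a made-up dataset of the integers 0 to 9 wrapped in TensorDataset, and passes a seeded torch.Generator only to make runs repeatable; neither is part of the original example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10))            # made-up samples 0..9
generator = torch.Generator().manual_seed(0)      # seed chosen arbitrarily
loader = DataLoader(data, batch_size=4, shuffle=True, generator=generator)

epochs = []
for epoch in range(2):
    # Each full pass over the loader is one epoch; the order is redrawn each time
    order = [batch[0].tolist() for batch in loader]
    epochs.append(order)
    print(f"epoch {epoch}: {order}")
```

Every epoch still visits each sample exactly once; only the order, and therefore the composition of each batch, changes.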
3. Iterating over the data
We have loaded the dataset into the DataLoader and can now iterate over it as needed. Each iteration below returns a batch of train_features and train_labels (containing batch_size=64 features and labels, respectively).
iter() returns an iterator over the DataLoader.
next() fetches the next batch of features and labels from that iterator.
train_features, train_labels = next(iter(train_dataloader))
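A self-contained way to try this is sketched below. It assumes random tensors of 150 made-up 1x28x28 images wrapped in TensorDataset as a stand-in for training_data; the sizes are arbitrary and chosen so that the last batch is incomplete.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Made-up stand-in for training_data: 150 random 1x28x28 "images", 10 classes
features = torch.randn(150, 1, 28, 28)
labels = torch.randint(0, 10, (150,))
train_dataloader = DataLoader(TensorDataset(features, labels), batch_size=64, shuffle=True)

# Fetch a single batch
train_features, train_labels = next(iter(train_dataloader))
print(train_features.shape)  # torch.Size([64, 1, 28, 28])
print(train_labels.shape)    # torch.Size([64])

# A for loop walks through every batch of one epoch
batch_sizes = [len(batch_labels) for _, batch_labels in train_dataloader]
print(batch_sizes)           # [64, 64, 22]
```

The last batch is smaller because 150 is not a multiple of 64; passing drop_last=True to DataLoader would discard it instead.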