Detailed explanation of Dataset and DataLoader in PyTorch's torch.utils.data

Deep learning inevitably involves datasets, but how exactly does a dataset get loaded into a model for training? Most beginners have at some point copied loading code straight from the Internet without understanding the underlying principle. In this post I analyze it in detail, covering both the built-in datasets and custom Dataset classes.

Foreword

torch.utils.data is the PyTorch module for processing and loading data. It provides a set of utility classes and functions for creating, manipulating, and batch-loading datasets.

Here are some commonly used classes and functions in the torch.utils.data module:

  • Dataset: An abstract dataset class; users build their own datasets by inheriting from it. A subclass must implement __getitem__, which returns an individual sample, and normally also __len__, which returns the size of the dataset (DataLoader's default samplers need it).
  • TensorDataset: A subclass of Dataset that wraps one or more tensors as a dataset. All tensors must have the same size in the first dimension, which determines the dataset size; sample i is the tuple made of the i-th entry of each tensor.
  • DataLoader: The data loader class, used to load a dataset in batches. It takes a dataset object and provides batching, shuffling, multi-process data loading, custom collation, and more.
  • Subset: A view of a dataset restricted to a given list of indices.
  • random_split: Randomly splits a dataset into non-overlapping subsets; you specify the length of each subset (recent versions also accept fractions that sum to 1).
  • ConcatDataset: Concatenates multiple datasets into one larger dataset.
  • get_worker_info: Returns information about the current DataLoader worker process (such as its id and the total number of workers); in the main process it returns None.
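A minimal sketch of several of these helpers working together, using a small toy tensor dataset (the data and the 8/2 split are arbitrary choices for illustration):

```python
import torch
from torch.utils.data import ConcatDataset, Subset, TensorDataset, random_split

# Toy data (assumption for illustration): 10 samples with 3 features each, labels 0..9
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)

# TensorDataset indexes every tensor along dim 0, so sample i is (features[i], labels[i])
full = TensorDataset(features, labels)
x0, y0 = full[0]
print(x0, y0)  # tensor([0., 1., 2.]) tensor(0)

# Subset keeps only the listed indices of the underlying dataset
first_three = Subset(full, [0, 1, 2])

# random_split shuffles the indices and cuts them into pieces of the requested lengths;
# a seeded generator makes the split reproducible
train_set, val_set = random_split(full, [8, 2], generator=torch.Generator().manual_seed(0))

# ConcatDataset chains datasets end to end
combined = ConcatDataset([train_set, val_set])

print(len(full), len(first_three), len(train_set), len(val_set), len(combined))  # 10 3 8 2 10
```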

Note that data preprocessing tools such as random cropping, random rotation, and normalization are not part of torch.utils.data; for images they come from torchvision.transforms, which is covered in Section 4 below.

With the classes and functions in torch.utils.data, you can easily load, process, and batch data for model training and validation. The two classes used most often, however, are Dataset and DataLoader.

1. Custom Dataset class

torch.utils.data.Dataset is the abstract class PyTorch uses to represent a dataset; it defines how samples are accessed and how many samples there are.

Dataset is a base class. To create a custom dataset, inherit from it and implement the following two methods:

__getitem__(self, index): returns the sample corresponding to the given index. The index is usually an integer from 0 to len(dataset) - 1, which is what DataLoader's default samplers generate; a custom dataset can also interpret it differently, for example as a key such as a file name.
__len__(self): returns the number of samples in the dataset.

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # Fetch a sample by index
        return self.data[index]

    def __len__(self):
        # Return the size of the dataset
        return len(self.data)

# Create the dataset object
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# Fetch a sample by index
sample = dataset[2]
print(sample)
# 3
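To connect this back to training, the same kind of class can be handed to a DataLoader, which calls __len__ and __getitem__ and stacks the samples into batches. A minimal sketch, assuming a hypothetical PairDataset whose samples are (feature, label) pairs:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    """Hypothetical example: each sample is a (feature tensor, integer label) pair."""

    def __init__(self, features, labels):
        assert len(features) == len(labels), "features and labels must align"
        self.features = features
        self.labels = labels

    def __getitem__(self, index):
        x = torch.tensor(self.features[index], dtype=torch.float32)
        y = self.labels[index]
        return x, y

    def __len__(self):
        return len(self.features)


dataset = PairDataset([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0, 1, 0])
loader = DataLoader(dataset, batch_size=2, shuffle=False)

batches = list(loader)
for xb, yb in batches:
    print(xb.shape, yb)
# torch.Size([2, 2]) tensor([0, 1])
# torch.Size([1, 2]) tensor([0])
```

Because each sample is a tuple, the default collate function returns a tuple of batched tensors, which is why the loop can unpack xb and yb; the last batch is smaller because 3 samples do not divide evenly by 2.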

The code above implements a custom Dataset class, which is what we generally write when training on our own data. As deep learning beginners, however, we usually start with built-in datasets such as MNIST and CIFAR-10, and for those we don't need to define a Dataset class ourselves. The reason is explained in detail below.

2. torchvision.datasets

To use the built-in datasets, you usually go through the torchvision.datasets module, which provides many common computer vision datasets such as MNIST, CIFAR10, and ImageNet.

Here is sample code using built-in datasets:

import torch
from torchvision import datasets, transforms

# Define the data transforms
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert the image to a tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize the image
])

# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

The code above loads and uses the built-in MNIST (handwritten digits) dataset. Notice that we did not use the torch.utils.data.Dataset class mentioned earlier. Why not?

Because the built-in dataset classes in torchvision.datasets already implement the torch.utils.data.Dataset interface, instantiating one directly gives a usable dataset object. So when using a built-in dataset, we simply instantiate its class instead of subclassing torch.utils.data.Dataset ourselves.

Built-in dataset classes such as torchvision.datasets.MNIST already define the __getitem__ and __len__ methods, so we can fetch samples and query the dataset size directly from the object. This is also why a built-in dataset object can be passed straight to torch.utils.data.DataLoader for loading and batching.

Under the hood, the built-in datasets are still implemented on top of torch.utils.data.Dataset; PyTorch simply wraps these commonly used datasets in ready-made classes for convenience and extra functionality, such as automatic downloading (download=True above).

To confirm this, I checked the loading code of the built-in datasets on the official PyTorch site, as shown in the figure below:
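It shows that the built-in dataset classes are indeed built on the Dataset class: torchvision.datasets.MNIST subclasses VisionDataset, which in turn subclasses torch.utils.data.Dataset.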
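The pattern such built-in classes follow can be sketched with a heavily simplified, hypothetical stand-in (TinyDigits is not a real torchvision class, and its random data replaces the real MNIST files): images and labels are stored as tensors, __getitem__ applies the optional transform, and __len__ reports the size. The real torchvision.datasets.MNIST additionally downloads the files and converts each image to a PIL image before applying transform.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TinyDigits(Dataset):
    """Hypothetical, simplified stand-in for a built-in dataset such as MNIST."""

    def __init__(self, num_samples=6, transform=None):
        g = torch.Generator().manual_seed(0)
        # Random uint8 "images" of shape (N, 28, 28), mimicking the layout of MNIST's raw data
        self.data = torch.randint(0, 256, (num_samples, 28, 28), dtype=torch.uint8, generator=g)
        self.targets = torch.arange(num_samples) % 10
        self.transform = transform

    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.data)


def to_float(img):
    # Scale uint8 pixels to [0, 1] and add a channel dimension: (28, 28) -> (1, 28, 28)
    return img.float().div(255).unsqueeze(0)


ds = TinyDigits(transform=to_float)
img, label = ds[0]
print(img.shape, label)  # torch.Size([1, 28, 28]) 0

# Like a built-in dataset, it can be passed straight to DataLoader
loader = DataLoader(ds, batch_size=4)
images, labels = next(iter(loader))
print(images.shape, labels)  # torch.Size([4, 1, 28, 28]) tensor([0, 1, 2, 3])
```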

3. DataLoader

torch.utils.data.DataLoader is PyTorch's utility class for loading data in batches. It takes a dataset object (for example, an instance of a torch.utils.data.Dataset subclass) and provides batching, data shuffling, parallel loading, and more.

The commonly used parameters of torch.utils.data.DataLoader are:

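  • dataset: The dataset to load from, for example an instance of a torch.utils.data.Dataset subclass.
  • batch_size: Number of samples per batch; the default is 1.
  • shuffle: Whether to shuffle the data; the default is False. If True, the data is reshuffled at every epoch.
  • num_workers: How many subprocesses to use for data loading; the default is 0, which means data is loaded in the main process. Values greater than 0 work on Linux, and also on Windows and macOS, where worker processes are started by spawning a new interpreter, so the code that creates and iterates the DataLoader must sit under an if __name__ == '__main__': guard. This is why many Windows tutorials simply use 0.
  • collate_fn: Function that merges the list of samples for one batch into the batch that is returned. If None, torch.utils.data._utils.collate.default_collate is used (recent versions also expose it as torch.utils.data.default_collate).
  • drop_last: Whether to drop the last incomplete batch when the dataset size is not divisible by batch_size; the default is False.
  • pin_memory: Whether to copy tensors into page-locked (pinned) host memory before returning them, which speeds up transfers to a CUDA GPU; the default is False.
  • prefetch_factor: Number of batches each worker loads in advance; it only applies when num_workers > 0. The effective default is 2 (recent versions declare the argument's default as None and use 2 when workers are present).
  • persistent_workers: If True, worker processes are kept alive after the dataset has been consumed once, instead of being shut down and recreated every epoch; the default is False.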

The sample code is as follows:

import torch
from torchvision import datasets, transforms

# Define the data transforms
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert the image to a tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize the image
])

# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create the data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# Iterate over batches with the data loader
for images, labels in train_loader:
    # Model training code goes here
    ...
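The effect of several of these parameters can be checked on a tiny in-memory dataset. This sketch uses a toy TensorDataset in place of MNIST (an assumption, so it runs without downloads) and a hypothetical collate function labels_only:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset (assumption): 10 samples whose feature and label are both 0..9
dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1), torch.arange(10))

# batch_size=4 over 10 samples gives batches of 4, 4 and 2
plain = DataLoader(dataset, batch_size=4, shuffle=False)
plain_labels = [y.tolist() for _, y in plain]
print(plain_labels)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# drop_last=True discards the final incomplete batch
dropped = DataLoader(dataset, batch_size=4, drop_last=True)
print(len(dropped))  # 2

# shuffle=True draws a new permutation every epoch; a seeded generator makes it reproducible.
# Either way, each sample appears exactly once per epoch.
shuffled = DataLoader(dataset, batch_size=5, shuffle=True, generator=torch.Generator().manual_seed(0))
seen = sorted(i for _, y in shuffled for i in y.tolist())
print(seen)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


# A custom collate_fn receives the list of samples for one batch and decides how to merge them;
# this hypothetical one keeps only the labels
def labels_only(samples):
    return torch.stack([y for _, y in samples])


custom = DataLoader(dataset, batch_size=3, collate_fn=labels_only)
first_custom = next(iter(custom))
print(first_custom)  # tensor([0, 1, 2])
```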

4. torchvision.transforms

The torchvision.transforms module handles image data preprocessing in PyTorch. It provides a series of transformations for common data conversion and augmentation operations when loading, training on, or running inference with image data. Some commonly used transforms are explained below:

  1. Resize: Resize the image

    • Resize(size): Resizes the image. If size is an integer, the shorter side is scaled to that length and the aspect ratio is preserved; if it is a tuple (h, w), the image is resized to exactly that size.
  2. ToTensor: Convert an image to a tensor

    • ToTensor(): Converts a PIL Image or numpy.ndarray of shape H x W x C into a float tensor of shape C x H x W, mapping uint8 pixel values from the range 0-255 to 0-1. This is the input format PyTorch models expect.
  3. Normalize: Normalize image data

    • Normalize(mean, std): Normalizes a tensor image channel by channel as (x - mean) / std, with one mean and one std value per channel. These values should match the statistics of the dataset in use (or those used to train a pretrained model), and Normalize must come after ToTensor.
  4. RandomHorizontalFlip: Randomly flip the image horizontally

    • RandomHorizontalFlip(p=0.5): Randomly flips the image horizontally with a given probability. Probability p controls the probability of flipping and defaults to 0.5.
  5. RandomCrop: Randomly crop an image

    • RandomCrop(size, padding=None): Randomly crops a region of the given size from the image. size can be an integer (square crop) or a tuple (h, w); padding optionally pads the image borders before cropping.
  6. ColorJitter: Color Jitter

    • ColorJitter(brightness=0, contrast=0, saturation=0, hue=0): Randomly adjusts the brightness, contrast, saturation, and hue of the image; each argument sets how far the corresponding property may vary.
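For reference, ToTensor and Normalize compute the following per pixel, where p_c is a uint8 pixel value in channel c. The worked values show why mean = std = 0.5, as used throughout this post, maps the [0, 1] range to [-1, 1]:

```latex
x_c = \frac{p_c}{255}, \qquad p_c \in \{0, 1, \dots, 255\}

x'_c = \frac{x_c - \mathrm{mean}_c}{\mathrm{std}_c}

\text{with } \mathrm{mean}_c = \mathrm{std}_c = 0.5:\quad
x_c = 0 \mapsto \frac{0 - 0.5}{0.5} = -1, \qquad
x_c = 0.5 \mapsto 0, \qquad
x_c = 1 \mapsto \frac{1 - 0.5}{0.5} = 1
```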

In practice, we usually chain these operations with transforms.Compose and then call the resulting object on each image.

The sample code is as follows:

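from PIL import Image
from torchvision import transforms

# Define the image preprocessing pipeline
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize the image to 256x256
    transforms.RandomCrop((224, 224)),  # Randomly crop a 224x224 patch
    transforms.RandomHorizontalFlip(),  # Randomly flip horizontally (p=0.5)
    transforms.ToTensor(),  # Convert to a C x H x W float tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize each channel to [-1, 1]
])

# Preprocess an image (a blank RGB placeholder here; in practice use Image.open(path).convert("RGB"))
image = Image.new("RGB", (320, 240))
image = transform(image)
print(image.shape)  # torch.Size([3, 224, 224])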

5. Defining a Dataset class for image classification

Take the eye disease dataset as an example (for details, see the basic case of deep learning practice - Convolutional Neural Network (CNN) Eye Disease Recognition Based on SqueezeNet|Example 1). After labeling the dataset, we generated train.txt and valid.txt files. Each file has two space-separated columns: the first is the image path and the second is the label (that is, the category), as shown below:
[Figure: sample lines from train.txt]
With these files in place, we can define our own dataset reading class. The code is as follows:

import os.path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

transform_BZ = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5]
)


class MyDataset(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag

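        self.train_tf = transforms.Compose([
            transforms.Resize((224, 224)),  # Resize to exactly 224x224 (an int would only fix the shorter side)
            transforms.RandomHorizontalFlip(),  # Randomly flip left-right
            transforms.RandomVerticalFlip(),  # Randomly flip upside down
            transforms.ToTensor(),  # Convert the PIL Image / numpy.ndarray to a tensor scaled to [0, 1]
            transform_BZ  # Normalize each channel with mean 0.5 and std 0.5
        ])
        self.val_tf = transforms.Compose([
            transforms.Resize((224, 224)),  # Same fixed size as in training, without augmentation
            transforms.ToTensor(),  # Convert the PIL Image / numpy.ndarray to a tensor scaled to [0, 1]
            transform_BZ  # Normalize each channel with mean 0.5 and std 0.5
        ])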

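    def get_images(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            # Each non-empty line is "<image path> <label>"; splitting on the last space only
            # keeps paths that contain spaces intact, and blank lines are skipped
            imgs_info = [line.strip().rsplit(' ', 1) for line in f if line.strip()]
        return imgs_info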

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]

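        # Joining with '' leaves the path unchanged; put a root directory in place of '' if the txt paths are relative
        img_path = os.path.join('', img_path)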
        img = Image.open(img_path)
        img = img.convert("RGB")
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)
        return img, label

    def __len__(self):
        return len(self.imgs_info)

After defining our own dataset class, we can pass in a txt file to read and preprocess the dataset. In a custom dataset class, the three key methods are __init__(), __getitem__() and __len__(): __init__() prepares the sample list and the transforms, while __getitem__() and __len__() are what DataLoader relies on to fetch and count samples. The transforms-based data augmentation, on the other hand, is optional: it is only a way to improve model performance, although most current training pipelines do include it.

# Load the training and validation sets
train_data = MyDataset(r"F:\SqueezeNet\train.txt", True)
train_dl = DataLoader(train_data, batch_size=16, pin_memory=True,
                      shuffle=True, num_workers=0)
test_data = MyDataset(r"F:\SqueezeNet\valid.txt", False)
test_dl = DataLoader(test_data, batch_size=16, pin_memory=True,
                     shuffle=False, num_workers=0)  # no need to shuffle validation data

Above, we load train.txt and valid.txt through the custom MyDataset class (the second argument True selects the augmented training transforms, while False selects the validation transforms without augmentation). We then wrap each dataset in a DataLoader for batching, and train_dl and test_dl can be fed directly to the model for training and validation.
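Feeding the loaders to a model typically looks like a standard training and validation loop. To keep this sketch runnable without image files, random tensors wrapped in a TensorDataset stand in for MyDataset, and the tiny linear classifier is hypothetical; with the real train_dl and test_dl, the model would need to match the real image size and number of classes.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-in data (assumption): 32 random 3x32x32 "images" with labels from 4 classes
images = torch.randn(32, 3, 32, 32)
labels = torch.randint(0, 4, (32,))
dataset = TensorDataset(images, labels)
train_dl = DataLoader(dataset, batch_size=16, shuffle=True)
test_dl = DataLoader(dataset, batch_size=16, shuffle=False)  # same toy data, for illustration only

# Hypothetical tiny classifier; any nn.Module taking (N, 3, H, W) is used the same way
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 4))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

for epoch in range(2):
    model.train()
    for x, y in train_dl:  # each batch: x has shape (16, 3, 32, 32), y has shape (16,)
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_dl:
            x, y = x.to(device), y.to(device)
            correct += (model(x).argmax(dim=1) == y).sum().item()
    print(f"epoch {epoch}: last loss={loss.item():.4f}, acc={correct / len(test_dl.dataset):.2f}")
```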




Original post: blog.csdn.net/m0_63007797/article/details/132385283