Customize a data loader with PyTorch

To customize a PyTorch data loader, define a subclass that inherits from torch.utils.data.Dataset. This subclass needs to implement the following three methods:

  • __init__(self, ...): Initialization function, used to read data and set data transformation, etc.
  • __len__(self): Returns the length of the dataset.
  • __getitem__(self, idx): Returns the data corresponding to the index idx.

Below is an example:

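import os

import torch
from torch.utils.data import Dataset
# Assumption: read_image comes from torchvision.io; the original does not say.
from torchvision.io import read_image

class MyDataset(Dataset):
    def __init__(self, data_dir):
        # Read the data: load every file in data_dir as an image tensor
        self.data = []
        for filename in os.listdir(data_dir):
            self.data.append(read_image(os.path.join(data_dir, filename)))

        # Set up data transformations, etc.

    def __len__(self):
        # Number of samples in the dataset
        return len(self.data)

    def __getitem__(self, idx):
        # Sample at position idx
        return self.data[idx]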

In the __init__ function, you can read the data and set up any data transformations. The __len__ function must return the length of the dataset, and the __getitem__ function must return the data corresponding to the index idx.
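The example above leaves the transformation step as a comment. A minimal sketch of one common way to wire it in is shown below: the transform is stored in __init__ and applied in __getitem__. The class name TransformedDataset, the scale_to_unit function, and the in-memory samples are hypothetical and used only for illustration; they are not part of the original example.

```python
import torch
from torch.utils.data import Dataset


class TransformedDataset(Dataset):
    """Toy dataset over in-memory tensors with an optional transform."""

    def __init__(self, samples, transform=None):
        # samples: any indexable collection of tensors (hypothetical input)
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # Apply the transform lazily, one sample at a time
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


def scale_to_unit(x):
    # Hypothetical transform: map uint8 pixel values to floats in [0, 1]
    return x.float() / 255.0


# Five fake 3x4x4 "images" filled with the value 255
samples = [torch.full((3, 4, 4), 255, dtype=torch.uint8) for _ in range(5)]
dataset = TransformedDataset(samples, transform=scale_to_unit)
first = dataset[0]
```

Applying the transform in __getitem__ rather than in __init__ keeps the stored data unchanged, so random augmentations (if you use any) can differ each time a sample is fetched.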

This dataset can then be loaded using DataLoader:

from torch.utils.data import DataLoader

batch_size = 32
# data_dir must be defined beforehand: the path to the directory containing the images
dataset = MyDataset(data_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

When training the model, you can iterate over dataloader to get the data of each batch:

for batch_data in dataloader:
    # Process batch_data
    pass
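To make the batching behavior concrete, here is a small self-contained sketch that can be run without any image files. RandomImageDataset is a hypothetical stand-in for MyDataset that returns random tensors; the dataset size of 70 and the 3x8x8 sample shape are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomImageDataset(Dataset):
    """Hypothetical stand-in dataset: n random 3x8x8 'images'."""

    def __init__(self, n):
        self.data = torch.rand(n, 3, 8, 8)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]


dataloader = DataLoader(RandomImageDataset(70), batch_size=32, shuffle=True)

batch_shapes = []
for batch_data in dataloader:
    # The default collate function stacks samples along a new first dimension
    batch_shapes.append(tuple(batch_data.shape))

print(batch_shapes)
```

With 70 samples and batch_size=32, the loop yields two batches of 32 and a final batch of 6, since drop_last defaults to False. Note that the default collation stacks samples into one tensor, so all samples need the same shape; images of different sizes would typically need resizing in the dataset or a custom collate_fn.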


Origin blog.csdn.net/qq_59109986/article/details/129361484