To customize a PyTorch data loader, define a subclass that inherits from torch.utils.data.Dataset. This subclass needs to implement the following three methods:
__init__(self, ...): Initialization function, used to read the data and set up data transformations, etc.
__len__(self): Returns the length of the dataset.
__getitem__(self, idx): Returns the data corresponding to the index idx.
Below is an example:
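import os

import torch
from torch.utils.data import Dataset
# Assumption: read_image is taken from torchvision.io; the original does not say where it comes from
from torchvision.io import read_image

class MyDataset(Dataset):
    def __init__(self, data_dir):
        # Read the data
        self.data = []
        for filename in os.listdir(data_dir):
            self.data.append(read_image(os.path.join(data_dir, filename)))
        # Set up data transformations, etc.

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]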
In the __init__ function, you can read the data and set up data transformations, etc. In the __len__ function, the length of the dataset needs to be returned. In the __getitem__ function, the data corresponding to the index idx needs to be returned.
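The example above leaves the "set up data transformations" step as a comment. One possible way to fill it in is sketched below: the transform is passed to __init__ and applied in __getitem__. The class name TransformedDataset, the transform parameter, and the synthetic tensors standing in for images are all assumptions made for illustration, not part of the original example.

```python
import torch
from torch.utils.data import Dataset


class TransformedDataset(Dataset):
    """Hypothetical dataset that keeps samples in memory and applies an optional transform."""

    def __init__(self, samples, transform=None):
        # samples: any indexable collection of tensors (assumption for this sketch)
        self.samples = samples
        # transform: optional callable applied to each sample when it is accessed
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


# Synthetic stand-in for image data: 10 uint8 "images" of shape (3, 8, 8)
raw = [torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8) for _ in range(10)]

# Example transform: convert to float and scale values into [0, 1]
scaled_dataset = TransformedDataset(raw, transform=lambda x: x.float() / 255.0)

first = scaled_dataset[0]
```

Applying the transform in __getitem__ rather than in __init__ keeps the raw data unchanged in memory, and it lets a random transform produce a different result each time a sample is requested.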
This dataset can then be loaded using DataLoader:
from torch.utils.data import DataLoader
batch_size = 32
# data_dir is the path to the directory holding the data; it must be defined by the caller
dataset = MyDataset(data_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
When training the model, you can iterate over dataloader to get each batch of data:
for batch_data in dataloader:
    # Process batch_data
    pass
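To see what the loop yields without needing an image directory, here is a minimal self-contained sketch that uses random tensors in place of real data. The RandomVectors class and the sizes (100 samples, 4 features) are hypothetical and chosen only to make the batching visible.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomVectors(Dataset):
    """Hypothetical in-memory dataset of feature vectors, used only to show batching."""

    def __init__(self, num_samples, num_features):
        self.data = torch.randn(num_samples, num_features)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


dataset = RandomVectors(num_samples=100, num_features=4)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

batch_shapes = []
for batch_data in dataloader:
    # The default collate function stacks samples along a new leading dimension
    batch_shapes.append(tuple(batch_data.shape))

print(batch_shapes)  # [(32, 4), (32, 4), (32, 4), (4, 4)]
```

The last batch holds only 4 samples because 100 is not a multiple of 32 and DataLoader keeps the incomplete batch by default. Passing drop_last=True discards it if every batch must have the same size.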