要自定义一个PyTorch的数据加载器,可以定义一个继承自torch.utils.data.Dataset
的子类。这个子类需要实现以下三个方法:
__init__(self, ...)
:初始化函数,用于读取数据和设置数据的变换等。__len__(self)
:返回数据集的长度。__getitem__(self, idx)
:返回索引idx
对应的数据。
下面是一个例子:
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data_dir):
# 读取数据
self.data = []
for filename in os.listdir(data_dir):
self.data.append(read_image(os.path.join(data_dir, filename)))
# 设置数据的变换等
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
在__init__
函数中,可以读取数据,并设置数据的变换等。在__len__
函数中,需要返回数据集的长度。在__getitem__
函数中,需要返回索引idx
对应的数据。
然后可以使用DataLoader
来加载这个数据集:
from torch.utils.data import DataLoader
batch_size = 32
dataset = MyDataset(data_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
在训练模型时,可以遍历dataloader
来获取每个batch的数据:
for batch_data in dataloader:
# 处理batch_data