使用PyTorch自定义一个数据加载器

要自定义一个PyTorch的数据加载器,可以定义一个继承自torch.utils.data.Dataset的子类。这个子类需要实现以下三个方法:

  • __init__(self, ...):初始化函数,用于读取数据和设置数据的变换等。
  • __len__(self):返回数据集的长度。
  • __getitem__(self, idx):返回索引idx对应的数据。

下面是一个例子:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data_dir):
        # 读取数据
        self.data = []
        for filename in os.listdir(data_dir):
            self.data.append(read_image(os.path.join(data_dir, filename)))

        # 设置数据的变换等

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

__init__函数中,可以读取数据,并设置数据的变换等。在__len__函数中,需要返回数据集的长度。在__getitem__函数中,需要返回索引idx对应的数据。

然后可以使用DataLoader来加载这个数据集:

from torch.utils.data import DataLoader

batch_size = 32
dataset = MyDataset(data_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

在训练模型时,可以遍历dataloader来获取每个batch的数据:

for batch_data in dataloader:
    # 处理batch_data

猜你喜欢

转载自blog.csdn.net/qq_59109986/article/details/129361484