PyTorch を使用してデータ ローダーをカスタマイズする

PyTorch データ ローダーをカスタマイズするには、torch.utils.data.Datasetを継承するサブクラスを定義します。このサブクラスは、次の 3 つのメソッドを実装する必要があります。

  • __init__(self, ...): データの読み込みやデータ変換の設定などに使用する初期化関数。
  • __len__(self): データセットの長さを返します。
  • __getitem__(self, idx):idxインデックスに対応するデータを返します。

以下に例を示します。

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data_dir):
        # 读取数据
        self.data = []
        for filename in os.listdir(data_dir):
            self.data.append(read_image(os.path.join(data_dir, filename)))

        # 设置数据的变换等

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

関数では__init__データの読み込みやデータの変換などの設定を行うことができます。関数では__len__、データセットの長さを返す必要があります。__getitem__関数では、idxインデックスに対応するデータを返す必要があります。

DataLoaderこのデータセットは、次を使用してロードできます。

from torch.utils.data import DataLoader

batch_size = 32
dataset = MyDataset(data_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

モデルをトレーニングするときに、dataloader各バッチのデータを取得するためにトラバースできます。

for batch_data in dataloader:
    # 处理batch_data

おすすめ

転載: blog.csdn.net/qq_59109986/article/details/129361484