PyTorch データ ローダーをカスタマイズするには、torch.utils.data.Dataset
を継承するサブクラスを定義します。このサブクラスは、次の 3 つのメソッドを実装する必要があります。
__init__(self, ...)
: データの読み込みやデータ変換の設定などに使用する初期化関数。__len__(self)
: データセットの長さを返します。__getitem__(self, idx)
:idx
インデックスに対応するデータを返します。
以下に例を示します。
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data_dir):
# 读取数据
self.data = []
for filename in os.listdir(data_dir):
self.data.append(read_image(os.path.join(data_dir, filename)))
# 设置数据的变换等
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
関数では__init__
データの読み込みやデータの変換などの設定を行うことができます。関数では__len__
、データセットの長さを返す必要があります。__getitem__
関数では、idx
インデックスに対応するデータを返す必要があります。
DataLoader
このデータセットは、次を使用してロードできます。
from torch.utils.data import DataLoader
batch_size = 32
dataset = MyDataset(data_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
モデルをトレーニングするときに、dataloader
各バッチのデータを取得するためにトラバースできます。
for batch_data in dataloader:
# 处理batch_data