Pytorchのtorch.utils.dataのDatasetとDataLoaderの詳細な説明

深層学習のプロセスではデータセットの使用が避けられませんが、そのデータセットはトレーニング用のモデルにどのように読み込まれるのでしょうか? これまで、初心者のほとんどはインターネット上のコードを直接使用していたはずですが、その基礎となる原理が何であるかは依然として明らかではありません。そこで今日は、組み込みの Dataset 関数とカスタム Dataset 関数から詳細な分析を行っていきます。

序文

torch.utils.dataPyTorchデータの処理とロードのために提供されるモジュールです。このモジュールは、データセットの作成、操作、一括読み込みのためのユーティリティ クラスと関数のセットを提供します。

torch.utils.dataモジュール内でよく使用されるクラスと関数をいくつか示します。

  • Dataset: 抽象データセット クラスを定義し、ユーザーはこのクラスを継承して独自のデータセットを構築できます。Datasetこのクラスには、実装する必要がある 2 つのメソッドが用意されています。1 つは__getitem__個々のサンプルにアクセスするため、もう 1 つ__len__はデータセットのサイズを返すためです。
  • TensorDataset:Datasetクラスから継承され、テンソル データをデータセットにパックするために使用されます。複数のテンソルを入力として受け取り、最初の入力テンソルのサイズに従ってデータセットのサイズを決定します。
  • DataLoader: データセットのバッチロードに使用されるデータローダークラス。データセット オブジェクトを入力として受け入れ、バッチ サイズの設定、マルチスレッド データ ロード、データ シャッフルなどのさまざまなデータ ロードおよび前処理機能を提供します。
  • Subset: データセットのサブセット クラス。データセットから指定されたサンプルを選択するために使用されます。
  • random_split: データセットを複数のサブセットにランダムに分割します。分割の比率や各サブセットのサイズを指定できます。
  • ConcatDataset: 複数のデータセットを結合して、より大きなデータセットを形成します。
  • get_worker_info: 現在のデータローダのプロセス情報を取得します。

上記のクラスと関数に加えて、torch.utils.dataランダム トリミング、ランダム回転、標準化など、一般的に使用されるデータ前処理ツールも提供されます。

torch.utils.dataモジュールによって提供されるクラスと関数を通じて、データを簡単にロード、処理、バッチロードできるため、モデルのトレーニングと検証が容易になります。ただし、最も頻繁に使用する 2 つのクラスは、DatasetクラスとDataLoaderクラスです。

1. カスタム データセット クラス

torch.utils.data.Datasetこれは、PyTorch でデータセットを表すために使用される抽象クラスであり、データセットのアクセス方法とサンプル数を定義するために使用されます。

Dataset クラスは基本クラスであり、このクラスを継承し、次の 2 つのメソッドを実装することでカスタム データセット クラスを作成できます。

getitem (self,index): 指定されたインデックスインデックスに従って、対応するサンプルデータを返します。インデックスは、サンプルを順番に取得することを意味する整数にすることも、ファイル名でサンプルを取得するなどの他の方法にすることもできます。
len (self): データセット内のサンプル数を返します。

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # 根据索引获取样本
        return self.data[index]

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

# 创建数据集对象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# 根据索引获取样本
sample = dataset[2]
print(sample)
# 3

上記のコード サンプルは主に自定义Dataset数据集类メソッドを実装しています。このメソッドは通常、独自のデータをトレーニングする必要がある場合に定義されます。ただし、一般的に、ディープ ラーニングの初心者は MNIST、CIFAR-10 などを使用します内置数据集。現時点では、Dataset クラスを自分で定義する必要はありません。その理由については、以下で詳しく説明します。

2、torchvision.datasets

PyTorch の組み込みデータセットを使用したい場合は、通常、torchvision.datasetsモジュールを通じて実行します。torchvision.datasetsこのモジュールは、MNIST、CIFAR10、ImageNet など、一般的に使用される多くのコンピューター ビジョン データセットを提供します。

以下は組み込みデータセットを使用したサンプルコードです。

import torch
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

上記のコードで実装したのは、組み込みの MNIST (手書き数字) データセットの読み込みと使用です。ご覧のとおり、ここでは上記のクラスを使用しませんでしたがtorch.utils.data.Dataset、それはなぜでしょうか?

これはtorchvision.datasets、モジュール内で、組み込みデータセット クラスがすでにtorch.utils.data.Datasetインターフェイスを実装しており、使用可能なデータセット オブジェクトを直接返すためです。したがって、組み込みデータセットを使用する場合、クラスを明示的に継承せずに、組み込みデータセット クラスを直接インスタンス化できますtorch.utils.data.Dataset

などの組み込みデータセット クラスの実装には、とメソッドの定義がtorchvision.datasets.MNISTすでに含まれており、これらを使用してサンプルを取得し、組み込みデータセット オブジェクトから直接データセットのサイズを決定できます。このようにして、組み込みデータセットを使用する場合、データの読み込みとバッチ処理のために組み込みデータセット オブジェクトを に直接渡すことができます。__getitem____len__torch.utils.data.DataLoader

組み込みデータセットの背後では、torch.utils.data.Dataset便宜上、またより多くの機能を提供するために、依然としてクラスに基づいて実装されています。PyTorch は、こ​​れらの一般的に使用されるデータセットを組み込みデータセット クラスにカプセル化します。

この目的のために、下の図に示すように、 pytorch 公式 Web サイトにアクセスして、組み込みデータセットの読み込みコードを確認しました。
ここに画像の説明を挿入
Dataset データセット クラスが実際に組み込みであることがわかります。

3、データローダー

torch.utils.data.DataLoaderPyTorchにデータを一括ロードするためのツールクラスです。データセット オブジェクト ( のtorch.utils.data.Datasetサブクラスなど) を受け入れ、データのロード、バッチ処理、データ シャッフルなどのさまざまな機能を提供します。

以下は、 のtorch.utils.data.DataLoader一般的に使用されるパラメータと関数です。

  • dataset: データセット オブジェクト。 のtorch.utils.data.Datasetサブクラス オブジェクトにすることができます。
  • batch_size: バッチごとのサンプル数、デフォルトは 1 です。
  • shuffle: データをシャッフルするかどうか。デフォルトは ですFalseデータはエポックごとにシャッフルされます。
  • num_workers: データのロードに使用される子プロセスの数。デフォルトは 0 で、これはメイン プロセスでデータをロードすることを意味します。実際、Windows システムでは 0 に設定されますが、Linux では 0 より大きい数値に設定できます。
  • collate_fn: バッチデータを返す前に各サンプルを処理する機能。「はい」の場合Nonetorch.utils.data._utils.collate.default_collateデフォルトでその関数を処理に使用します。
  • drop_last: 最後のサンプル サイズが 1 バッチに満たないデータを破棄するかどうか。デフォルトは ですFalse
  • pin_memory: ロードしたデータを CUDA に対応した固定メモリに保存するかどうか。デフォルトは ですFalse
  • prefetch_factor: プリフェッチ係数。デバイスにデータをプリフェッチするために使用されます。デフォルトは 2 です。
  • persistent_workers: true の場合True、エポックごとにデータをロードするために永続的なサブプロセスを使用します。デフォルトは ですFalse

サンプルコードは次のとおりです。

import torch
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# 使用数据加载器迭代样本
for images, labels in train_loader:
    # 训练模型的代码
    ...

4、torchvision.transforms

torchvision.transformsmodule は、PyTorch で画像データの前処理を行うための機能モジュールです。画像データのロード、トレーニング、推論時にさまざまな一般的なデータ変換および拡張操作を実行するための一連の変換関数を提供します。以下は、一般的に使用されるいくつかの変換関数の詳細な説明です。

  1. サイズ変更: 画像のサイズを変更します

    • Resize(size): 画像のサイズを指定されたサイズに変更します。短辺のサイズとして整数を受け入れることも、画像のターゲット サイズとしてタプルまたはリストを受け入れることもできます。
  2. ToTensor: 画像をテンソルに変換します

    • ToTensor(): 画像をテンソルに変換し、0 ~ 255 から 0 ~ 1 の範囲のピクセル値をマッピングします。画像データを深層学習モデルに渡すのに適しています。
  3. 正規化:画像データを正規化します。

    • Normalize(mean, std):画像データを正規化します。渡される平均値と標準偏差は、ピクセル値正規化の平均値と標準偏差です。平均値と標準値は以前に使用したデータセットに対応する必要があることに注意してください。
  4. Randomhorizo​​ntalFlip: ランダムな水平反転画像

    • RandomHorizontalFlip(p=0.5):一定の確率で画像をランダムに左右反転します。Probability p は反転の確率を制御し、デフォルトは 0.5 です。
  5. RandomCrop: 画像をランダムにトリミングします。

    • RandomCrop(size, padding=None): 画像を指定されたサイズにランダムにトリミングします。タプルまたは整数をターゲット サイズとして指定でき、オプションでパディング値を指定できます。
  6. ColorJitter: カラージッター

    • ColorJitter(brightness=0, contrast=0, saturation=0, hue=0): 画像の明るさ、コントラスト、彩度、色相をランダムに調整します。画像の外観は、さまざまなパラメータを設定することで調整できます。

使用する際にはtransforms.Composeこれらのデータ処理を組み合わせて使用​​することが多いですが、使用する場合はその組み合わせを直接呼び出すだけです。

サンプルコードは次のとおりです。

from torchvision import transforms

# 定义图像预处理操作
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 缩放图像大小为 (256, 256)
    transforms.RandomCrop((224, 224)),  # 随机裁剪图像为 (224, 224)
    transforms.RandomHorizontalFlip(),  # 随机水平翻转图像
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化图像
])

# 对图像进行预处理
image = transform(image)

5. 画像分類における Dataset データセットクラスの定義

例として眼疾患データセットを取り上げます (詳細については、深層学習実践の基本的なケース - SqueezeNet に基づく畳み込みニューラル ネットワーク (CNN) 眼疾患認識 | 例 1 を参照してください)。データセットにラベルを付けた後にトレインを生成しました。および valid.txt ファイルでは、このファイルには 2 つの列があり、次のように、最初の列はデータ セットのパス、2 番目の列はデータ セットのラベル (つまり、カテゴリ) です。 、
ここに画像の説明を挿入
独自のデータセット読み取りクラスを定義できます。具体的なコードは次のとおりです。

import os.path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms

transform_BZ = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5]
)


class MyDataset(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag

        self.train_tf = transforms.Compose([
            transforms.Resize(224),  # 调整图像大小为224x224
            transforms.RandomHorizontalFlip(),  # 随机左右翻转图像
            transforms.RandomVerticalFlip(),  # 随机上下翻转图像
            transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
            transform_BZ  # 执行某些复杂变换操作
        ])
        self.val_tf = transforms.Compose([
            transforms.Resize(224),  # 调整图像大小为224x224
            transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
            transform_BZ  # 执行某些复杂变换操作
        ])

    def get_images(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x: x.strip().split(' '), imgs_info))
        return imgs_info

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]

        img_path = os.path.join('', img_path)
        img = Image.open(img_path)
        img = img.convert("RGB")
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)
        return img, label

    def __len__(self):
        return len(self.imgs_info)

独自のデータセット読み取りクラスを定義した後、txt ファイルを渡してデータセットを前処理して読み取ることができます。カスタム データセット クラスで最も重要な 3 つのメソッドは __init__()、getitem ()、および __len__() であり、これらはすべて不可欠です。同時に、変換のデータ強調操作は必要ありません。これはモデルのパフォーマンスを向上させるための単なる方法ですが、現在のモデル トレーニング プロセスでは通常、データ強調操作が追加されます。

# 加载训练集和验证集
train_data = MyDataset(r"F:\SqueezeNet\train.txt", True)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)
test_data = MyDataset(r"F:\SqueezeNet\valid.txt", False)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)

上記では、カスタム MyDataset クラスを通じて train.txt ファイルと valid.txt ファイルをそれぞれロードしました (次の True パラメーターはトレーニング セットのデータを強化することを意味し、False は検証セットのデータを強化することを意味します)。 。次に、DataLoader を使用してデータセットをバッチロードし、ロードされたデータをトレーニング用のモデルtrain_dl に直接スローできます。test_dl


具体的な例は以下を参照してください。

おすすめ

転載: blog.csdn.net/m0_63007797/article/details/132385283