PyTorchの基本(4)-----データの読み込みと前処理

序文

PyTorchのTensor、Autograd、torch.nn、torch.optimパッケージについて簡単に説明しました。これらを使用すると、ネットワークモデルを簡単に構築できますが、これだけでは不十分です。大量のデータも必要です。ご存知のとおり、データ深さです。学習の精神である深層学習モデルは、データによって「供給」されます。この記事では、データの読み込みと前処理について説明します。

  • まず、トーチパッケージを紹介する必要があります
import torch
torch.__version__

1.データの読み込み

PyTorchは、torch.utils.dataを介して一般的に使用されるデータの読み込みをカプセル化します。これにより、マルチスレッドデータの先読みとバッチ読み込みを簡単に実現できます。

1.1データセット

データセットは抽象クラスです。読みやすくするために、使用するデータをデータセットクラスとしてパッケージ化する必要があります。カスタムデータセットクラスはそれを継承し、2つのメンバーメソッドを実装する必要があります。

  • 1 .__ getitem __():このメソッドは、インデックス(0-len(self))を持つデータまたはサンプルを取得するように定義します。
  • 2 .__ len __():このメソッドは、データセットの全長を返します

以下では、Kaggleのブルドーザー向けの競争ブルーブックを使用してデータセットカスタマイズします。紹介の便宜上、内部のデータディクショナリを使用して説明します。

  • まず、関連するパッケージを参照する必要があります
from torch.utils.data import Dataset
import pandas as pd
  • データセットをカスタマイズする
#定义一个数据集
class BulldozerDataset(Dataset):
    """ 数据集演示 """
    def __init__(self, csv_file):
        """实现初始化方法,在初始化的时候将数据读载入"""
        self.df=pd.read_csv(csv_file)
    def __len__(self):
        '''
        返回df的长度
        '''
        return len(self.df)
    def __getitem__(self, idx):
        '''
        根据 idx 返回一行数据
        '''
        return self.df.iloc[idx].SalePrice
  • この時点で、データセットが定義され、アクセスするオブジェクトをインスタンス化できます。
ds_demo= BulldozerDataset('median_benchmark.csv')
  • 次のコマンドを直接使用して、データセットデータを表示できます
# 前面我们已经实现了__len__方法,所以可以直接使用
len(ds_demo)
  • インデックスを使用して、対応するデータに直接アクセスします
ds_demo[0]

カスタムデータセットが作成されました。以下では、公式のデータローダーを使用してデータを読み取ります

1.2 DataLoader

DataLoaderは、データセットの読み取り操作を提供します。一般的なパラメーターは、batch_size(各バッチのサイズ)、shuffle(シャッフル操作を実行するかどうか)、num_workers(データのロード時に複数のサブプロセスを使用)です。これが簡単なデモンストレーションです:

dl = torch.utils.data.DataLoader(ds_demo,batch_size = 10,shuffle = True,num_workers = 0)

DataLoaderが返すのは反復可能なオブジェクトであり、イテレータを使用して段階的にデータを取得できます

idata=iter(dl)
print(next(idata))

一般的な使用法は、forループを使用してトラバースすることです。

for i, data in enumerate(dl):
    print(i,data)
    # 为了节约空间,这里只循环一遍
    break

この時点で、データセットを介してデータセットを定義し、DataLorderを使用してデータセットをロードおよびトラバースできます。

2、トーチビジョンパッケージ

torchvisionは、PyTorchでの画像処理専用のライブラリです。PyTorch公式Webサイトのインストールチュートリアルの最後のpip install torchvisionは、このパッケージをインストールすることです。
Torchvisionは、以前に使用されていたCIFAR-10、ImageNet、COCO、MNIST、LSUN、およびtorchvision.datasetsから簡単に呼び出すことができるその他のデータセットを含む、一般的な画像データセットを事前に実装しています。

  • torchvisionがプリインストールされているデータセットの概要は次のとおりです。
データセット名
MNIST
COCO
CIFAR-10
ImageNet
キャプション
検出
LSUN
ImageFolder
Imagenet-12
STL10
SVHN
PhotoTour

PyTorchに付属するデータセットは、2つの上位レベルのAPI、つまりtorchvisionとtorchtextによって提供されます。

  • Torchvisionは、画像データ処理に関連するデータとAPIを提供します
    • データの場所:torchvision.datasets ;例:torchvision.datasets.MNIST
  • Torchtextは、テキストデータ処理に関連するデータとAPIを提供します
    • データの場所:torchtext.datasets;例:torchtext.datasets.IMDB

簡単なデモンストレーションをしましょう

  • まず、トーチビジョンパッケージを紹介する必要があります
import torchvision.datasets as datasets
trainset = datasets.MNIST(root='./data', # 表示 MNIST 数据的加载的目录
                                      train=True,  # 表示是否加载数据库的训练集,false的时候加载测试集
                                      download=True, # 表示是否自动下载 MNIST 数据集
                                      transform=None) # 表示是否需要对数据进行预处理,none为不进行预处理

2.1 torchvision.models

Torchvisionは、一般的に使用される画像データセットを提供するだけでなく、ロード直後に使用したり、学習を転送し続けたりできる、トレーニング済みのネットワークモデルも提供します。torchvision.modelsモジュールのサブモジュールには、次のモデルが含まれています。

ネットワークモデル
AlexNet
VGG
ResNet
SqueezeNet
DenseNet

トレーニング済みのモデルを直接使用できます。もちろん、これはデータセットと同じであり、サーバーからダウンロードする必要があります。

  • まず、torchvision.modelsをインポートする必要があります
import torchvision.models as models
  • 直接使用する
resnet18 = models.resnet18(pretrained=True)

2.2 torchvision.tranforms

変換モジュールは、データ処理とデータ拡張のための一般的な画像変換操作クラスを提供します

  • まず、torchvision.tranformsを紹介してから、簡単なデモンストレーションを行う必要があります
from torchvision import transforms as transforms
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  #先四周填充0,在把图像随机裁剪成32*32
    transforms.RandomHorizontalFlip(),  #图像一半的概率翻转,一半的概率不翻转
    transforms.RandomRotation((-45,45)), #随机旋转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), #R,G,B每层的归一化用到的均值和方差
])

誰かが間違いなく尋ねるでしょう:(0.485、0.456、0.406)、(0.2023、0.1994、0.2010)これらの数字はどういう意味ですか?
公式投稿には詳細な手順があります:https//discuss.pytorch.org/t/normalization-in-the-mnist-example/457/21これらはImageNetでトレーニングされた正規化パラメーターであり、直接使用できます。これは固定値。
この時点で、PyTorchの基本的なコンテンツの紹介は完了です。

参照

https://github.com/zergtant/pytorch-handbook/blob/master/chapter2

おすすめ

転載: blog.csdn.net/dongjinkun/article/details/113869697
おすすめ