序文
PyTorchのTensor、Autograd、torch.nn、torch.optimパッケージについて簡単に説明しました。これらを使用すると、ネットワークモデルを簡単に構築できますが、これだけでは不十分です。大量のデータも必要です。ご存知のとおり、データ深さです。学習の精神である深層学習モデルは、データによって「供給」されます。この記事では、データの読み込みと前処理について説明します。
- まず、トーチパッケージを紹介する必要があります
import torch
torch.__version__
1.データの読み込み
PyTorchは、torch.utils.dataを介して一般的に使用されるデータの読み込みをカプセル化します。これにより、マルチスレッドデータの先読みとバッチ読み込みを簡単に実現できます。
1.1データセット
データセットは抽象クラスです。読みやすくするために、使用するデータをデータセットクラスとしてパッケージ化する必要があります。カスタムデータセットクラスはそれを継承し、2つのメンバーメソッドを実装する必要があります。
- 1 .__ getitem __():このメソッドは、インデックス(0-len(self))を持つデータまたはサンプルを取得するように定義します。
- 2 .__ len __():このメソッドは、データセットの全長を返します
以下では、Kaggleのブルドーザー向けの競争ブルーブックを使用してデータセットをカスタマイズします。紹介の便宜上、内部のデータディクショナリを使用して説明します。
- まず、関連するパッケージを参照する必要があります
from torch.utils.data import Dataset
import pandas as pd
- データセットをカスタマイズする
#定义一个数据集
class BulldozerDataset(Dataset):
""" 数据集演示 """
def __init__(self, csv_file):
"""实现初始化方法,在初始化的时候将数据读载入"""
self.df=pd.read_csv(csv_file)
def __len__(self):
'''
返回df的长度
'''
return len(self.df)
def __getitem__(self, idx):
'''
根据 idx 返回一行数据
'''
return self.df.iloc[idx].SalePrice
- この時点で、データセットが定義され、アクセスするオブジェクトをインスタンス化できます。
ds_demo= BulldozerDataset('median_benchmark.csv')
- 次のコマンドを直接使用して、データセットデータを表示できます
# 前面我们已经实现了__len__方法,所以可以直接使用
len(ds_demo)
- インデックスを使用して、対応するデータに直接アクセスします
ds_demo[0]
カスタムデータセットが作成されました。以下では、公式のデータローダーを使用してデータを読み取ります
1.2 DataLoader
DataLoaderは、データセットの読み取り操作を提供します。一般的なパラメーターは、batch_size(各バッチのサイズ)、shuffle(シャッフル操作を実行するかどうか)、num_workers(データのロード時に複数のサブプロセスを使用)です。これが簡単なデモンストレーションです:
dl = torch.utils.data.DataLoader(ds_demo,batch_size = 10,shuffle = True,num_workers = 0)
DataLoaderが返すのは反復可能なオブジェクトであり、イテレータを使用して段階的にデータを取得できます
idata=iter(dl)
print(next(idata))
一般的な使用法は、forループを使用してトラバースすることです。
for i, data in enumerate(dl):
print(i,data)
# 为了节约空间,这里只循环一遍
break
この時点で、データセットを介してデータセットを定義し、DataLorderを使用してデータセットをロードおよびトラバースできます。
2、トーチビジョンパッケージ
torchvisionは、PyTorchでの画像処理専用のライブラリです。PyTorch公式Webサイトのインストールチュートリアルの最後のpip install torchvisionは、このパッケージをインストールすることです。
Torchvisionは、以前に使用されていたCIFAR-10、ImageNet、COCO、MNIST、LSUN、およびtorchvision.datasetsから簡単に呼び出すことができるその他のデータセットを含む、一般的な画像データセットを事前に実装しています。
- torchvisionがプリインストールされているデータセットの概要は次のとおりです。
データセット名 |
---|
MNIST |
COCO |
CIFAR-10 |
ImageNet |
キャプション |
検出 |
LSUN |
ImageFolder |
Imagenet-12 |
STL10 |
SVHN |
PhotoTour |
PyTorchに付属するデータセットは、2つの上位レベルのAPI、つまりtorchvisionとtorchtextによって提供されます。
- Torchvisionは、画像データ処理に関連するデータとAPIを提供します
- データの場所:torchvision.datasets ;例:torchvision.datasets.MNIST
- Torchtextは、テキストデータ処理に関連するデータとAPIを提供します
- データの場所:torchtext.datasets;例:torchtext.datasets.IMDB
簡単なデモンストレーションをしましょう
- まず、トーチビジョンパッケージを紹介する必要があります
import torchvision.datasets as datasets
trainset = datasets.MNIST(root='./data', # 表示 MNIST 数据的加载的目录
train=True, # 表示是否加载数据库的训练集,false的时候加载测试集
download=True, # 表示是否自动下载 MNIST 数据集
transform=None) # 表示是否需要对数据进行预处理,none为不进行预处理
2.1 torchvision.models
Torchvisionは、一般的に使用される画像データセットを提供するだけでなく、ロード直後に使用したり、学習を転送し続けたりできる、トレーニング済みのネットワークモデルも提供します。torchvision.modelsモジュールのサブモジュールには、次のモデルが含まれています。
ネットワークモデル |
---|
AlexNet |
VGG |
ResNet |
SqueezeNet |
DenseNet |
トレーニング済みのモデルを直接使用できます。もちろん、これはデータセットと同じであり、サーバーからダウンロードする必要があります。
- まず、torchvision.modelsをインポートする必要があります
import torchvision.models as models
- 直接使用する
resnet18 = models.resnet18(pretrained=True)
2.2 torchvision.tranforms
変換モジュールは、データ処理とデータ拡張のための一般的な画像変換操作クラスを提供します
- まず、torchvision.tranformsを紹介してから、簡単なデモンストレーションを行う必要があります
from torchvision import transforms as transforms
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4), #先四周填充0,在把图像随机裁剪成32*32
transforms.RandomHorizontalFlip(), #图像一半的概率翻转,一半的概率不翻转
transforms.RandomRotation((-45,45)), #随机旋转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), #R,G,B每层的归一化用到的均值和方差
])
誰かが間違いなく尋ねるでしょう:(0.485、0.456、0.406)、(0.2023、0.1994、0.2010)これらの数字はどういう意味ですか?
公式投稿には詳細な手順があります:https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/21これらはImageNetでトレーニングされた正規化パラメーターであり、直接使用できます。これは固定値。
この時点で、PyTorchの基本的なコンテンツの紹介は完了です。
参照
https://github.com/zergtant/pytorch-handbook/blob/master/chapter2