[pytorch、learning] -3.5画像分類データセット

参照

3.5。画像分類データセット

shftmax回帰の実装を紹介する前に、まずマルチクラス画像分類データセットを紹介します

この章では、pytorchを使用して実装し始めました〜

このセクションでは、PyTorchディープラーニングフレームワークを提供し、主にコンピュータービジョンモデルの構築に使用されるトーチビジョンパッケージを使用します。Torchvisionは、主に次の部分で構成されています。

  1. torchvision.datasets:データをロードするためのいくつかの関数と一般的に使用されるデータセットインターフェイス
  2. torchvision.models:AlexNet、VGG、ResNetなどの一般的に使用されるモデル(事前トレーニング済みモデルを含む)が含まれています。
  3. torchvision.transforms:トリミング、回転など、一般的に使用される画像変換。
  4. torchvision.utils:他のいくつかの便利な方法

3.5.1。データセットを取得する

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append('..')  # 为了导入上层的d2lzh_pytorch
import d2lzh_pytorch as d2l

以下では、torchvision.datasetこのデータセットをtorchvisionからダウンロードします。データは、最初に呼び出されたときにインターネットから自動的にダウンロードされます。パラメータを使用trainして、トレーニングセットまたはテストデータセット(テストデータ)を取得するように指定します。

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

データセットのダウンロードが遅い場合は、Webサイトから4ダウンロードし、ダウンロードしたディレクトリに従ってインポートすることをお勧めします。次の図に示すように、train-images-idx3-ubyte.gzをローカルディレクトリC:\ Users \ 1 / Datasets / FashionMNIST \ FashionMNIST \ rawにダウンロードしました(4つのファイル、ここでは1つだけ、他は同様です)。 \。4をURLから直接(解凍する必要はありません)ディレクトリにダウンロードし、上記のコードを実行します。

ここに画像の説明を挿入

# 上面的 mnist_train 和 mnist_test都是 torch.utils.data.Datasets的子类
# 所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本
# 训练集中和测试集中的每个类别的图像分别是6000和1000。因为有10个类别,所以训练集和测试集的样本数分别为60000和10000
print(type(mnist_train))   # <class 'totchvision.datasets.mnist.FashionMNIST'>
print(len(mnist_train), len(mnist_test))   # 60000 10000
# 通过下标访问任意样本
feature, label = mnist_train[0]
print(feature.shape, label)   # torch.Size([1, 28, 28]) tensor(5)

ここに画像の説明を挿入
この変数featureは、高さと幅が両方とも28ピクセルの画像に対応します使用したのでtransforms.ToTensor()、各ピクセルの値は32ビットの浮動小数点数[0.0、1.0]です。フィーチャーのサイズは(H * W * C)ではなく(C * H * W)であることに注意してください。データチャネルの数は1であるため、最初の次元はチャネルの数です。次の2つの次元は、画像の高さと幅です。

ファッションには合計10のカテゴリがあります-MNIST、0、1、2、3、4、5、6、7、8、9

# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
#     text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    text_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    return [text_labels[int(i)] for i in labels]

# 定义一个可以在一行里面画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    _, figs = plt.subplots(1, len(images), figsize= (12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

X, y = [], []
for i in range(10):
    # 从数据集中取出10个
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

ここに画像の説明を挿入

3.5.2。小さなバッチを読む

トレーニングデータセットでモデルをトレーニングし、テストセットでモデルのパフォーマンスを評価します。前述のように、これmnist_traintorch.utils.data.Datasetサブクラスであるため、これを渡して、torch.utils.data.DataLoaderデータサンプルの小さなバッチを読み取るDataLoaderインスタンスを作成できます。

batch_size = 256
if sys.platform.startswith('win'):
    num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:
    num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

start = time.time()
for X,y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))

ここに画像の説明を挿入

おすすめ

転載: blog.csdn.net/piano9425/article/details/107150167