参照
3.5。画像分類データセット
shftmax回帰の実装を紹介する前に、まずマルチクラス画像分類データセットを紹介します
この章では、pytorchを使用して実装し始めました〜
このセクションでは、PyTorchディープラーニングフレームワークを提供し、主にコンピュータービジョンモデルの構築に使用されるトーチビジョンパッケージを使用します。Torchvisionは、主に次の部分で構成されています。
torchvision.datasets
:データをロードするためのいくつかの関数と一般的に使用されるデータセットインターフェイスtorchvision.models
:AlexNet、VGG、ResNetなどの一般的に使用されるモデル(事前トレーニング済みモデルを含む)が含まれています。torchvision.transforms
:トリミング、回転など、一般的に使用される画像変換。torchvision.utils
:他のいくつかの便利な方法
3.5.1。データセットを取得する
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append('..') # 为了导入上层的d2lzh_pytorch
import d2lzh_pytorch as d2l
以下では、torchvision.dataset
このデータセットをtorchvisionからダウンロードします。データは、最初に呼び出されたときにインターネットから自動的にダウンロードされます。パラメータを使用train
して、トレーニングセットまたはテストデータセット(テストデータ)を取得するように指定します。
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
データセットのダウンロードが遅い場合は、Webサイトから4をダウンロードし、ダウンロードしたディレクトリに従ってインポートすることをお勧めします。次の図に示すように、train-images-idx3-ubyte.gzをローカルディレクトリC:\ Users \ 1 / Datasets / FashionMNIST \ FashionMNIST \ rawにダウンロードしました(4つのファイル、ここでは1つだけ、他は同様です)。 \。4をURLから直接(解凍する必要はありません)ディレクトリにダウンロードし、上記のコードを実行します。
# 上面的 mnist_train 和 mnist_test都是 torch.utils.data.Datasets的子类
# 所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本
# 训练集中和测试集中的每个类别的图像分别是6000和1000。因为有10个类别,所以训练集和测试集的样本数分别为60000和10000
print(type(mnist_train)) # <class 'totchvision.datasets.mnist.FashionMNIST'>
print(len(mnist_train), len(mnist_test)) # 60000 10000
# 通过下标访问任意样本
feature, label = mnist_train[0]
print(feature.shape, label) # torch.Size([1, 28, 28]) tensor(5)
この変数feature
は、高さと幅が両方とも28ピクセルの画像に対応します。使用したのでtransforms.ToTensor()
、各ピクセルの値は32ビットの浮動小数点数[0.0、1.0]です。フィーチャーのサイズは(H * W * C)ではなく(C * H * W)であることに注意してください。データチャネルの数は1であるため、最初の次元はチャネルの数です。次の2つの次元は、画像の高さと幅です。
ファッションには合計10のカテゴリがあります-MNIST、0、1、2、3、4、5、6、7、8、9
# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
# text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
text_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
return [text_labels[int(i)] for i in labels]
# 定义一个可以在一行里面画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
_, figs = plt.subplots(1, len(images), figsize= (12, 12))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
X, y = [], []
for i in range(10):
# 从数据集中取出10个
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
3.5.2。小さなバッチを読む
トレーニングデータセットでモデルをトレーニングし、テストセットでモデルのパフォーマンスを評価します。前述のように、これmnist_train
はtorch.utils.data.Dataset
サブクラスであるため、これを渡して、torch.utils.data.DataLoader
データサンプルの小さなバッチを読み取るDataLoaderインスタンスを作成できます。
batch_size = 256
if sys.platform.startswith('win'):
num_workers = 0 # 0表示不用额外的进程来加速读取数据
else:
num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
start = time.time()
for X,y in train_iter:
continue
print('%.2f sec' % (time.time() - start))