Справка
3.5. Набор данных классификации изображений
Перед тем, как представить реализацию регрессии shftmax, мы сначала вводим набор данных для классификации изображений с несколькими классами.
В этой главе мы начали использовать pytorch для его реализации ~
В этом разделе мы будем использовать пакет torchvision, который обслуживает среду глубокого обучения PyTorch и в основном используется для создания моделей компьютерного зрения. Torchvision в основном состоит из следующих частей:
torchvision.datasets
: Некоторые функции для загрузки данных и часто используемые интерфейсы наборов данныхtorchvision.models
: Содержит часто используемые модели (включая предварительно обученные модели), такие как AlexNet, VGG, ResNet и т. Д.torchvision.transforms
: Часто используемые преобразования изображений, такие как обрезка, поворот и т.д .;torchvision.utils
: Некоторые другие полезные методы
3.5.1. Получение набора данных
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append('..') # 为了导入上层的d2lzh_pytorch
import d2lzh_pytorch as d2l
Ниже мы torchvision.dataset
загружаем этот набор данных через torchvision . Данные будут автоматически загружены из Интернета при первом вызове. Мы используем параметры, train
чтобы указать, чтобы получить обучающий набор или набор данных тестирования (данные тестирования).
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
Если загрузка набора данных идет медленно, рекомендуется загрузить 4 через веб-сайт , а затем импортировать его в соответствии с загруженным каталогом. Как показано на рисунке ниже, я загрузил (4 файла, здесь упоминается только один, другие похожи) train-images-idx3-ubyte.gz в локальный каталог C: \ Users \ 1 / Datasets / FashionMNIST \ FashionMNIST \ raw \. Загрузите 4 прямо из URL-адреса (не нужно распаковывать) в каталог и выполните приведенный выше код.
# 上面的 mnist_train 和 mnist_test都是 torch.utils.data.Datasets的子类
# 所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本
# 训练集中和测试集中的每个类别的图像分别是6000和1000。因为有10个类别,所以训练集和测试集的样本数分别为60000和10000
print(type(mnist_train)) # <class 'totchvision.datasets.mnist.FashionMNIST'>
print(len(mnist_train), len(mnist_test)) # 60000 10000
# 通过下标访问任意样本
feature, label = mnist_train[0]
print(feature.shape, label) # torch.Size([1, 28, 28]) tensor(5)
Переменная feature
соответствует изображению, высота и ширина которого составляют 28 пикселей. Поскольку мы использовали его transforms.ToTensor()
, значение каждого пикселя представляет собой 32-битное число с плавающей запятой [0.0, 1.0]. Следует отметить, что размер элемента равен (C * H * W), а не (H * W * C). Первое измерение - это количество каналов, потому что количество каналов данных равно 1. Следующие два измерения - это высота и ширина изображения.
Всего в Fashion-MNIST 10 категорий: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9.
# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
# text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
text_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
return [text_labels[int(i)] for i in labels]
# 定义一个可以在一行里面画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
_, figs = plt.subplots(1, len(images), figsize= (12, 12))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
X, y = [], []
for i in range(10):
# 从数据集中取出10个
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
3.5.2. Чтение небольших партий
Мы обучим модель на обучающем наборе данных и оценим производительность модели на тестовом наборе. Как упоминалось ранее, это mnist_train
является torch.utils.data.Dataset
подклассом, так что мы можем передать его в , torch.utils.data.DataLoader
чтобы создать экземпляр DataLoader , который читает небольшие партии образцов данных
batch_size = 256
if sys.platform.startswith('win'):
num_workers = 0 # 0表示不用额外的进程来加速读取数据
else:
num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
start = time.time()
for X,y in train_iter:
continue
print('%.2f sec' % (time.time() - start))