[pytorch, learning] -3.5 набор данных классификации изображений

Справка

3.5. Набор данных классификации изображений

Перед тем, как представить реализацию регрессии shftmax, мы сначала вводим набор данных для классификации изображений с несколькими классами.

В этой главе мы начали использовать pytorch для его реализации ~

В этом разделе мы будем использовать пакет torchvision, который обслуживает среду глубокого обучения PyTorch и в основном используется для создания моделей компьютерного зрения. Torchvision в основном состоит из следующих частей:

  1. torchvision.datasets: Некоторые функции для загрузки данных и часто используемые интерфейсы наборов данных
  2. torchvision.models: Содержит часто используемые модели (включая предварительно обученные модели), такие как AlexNet, VGG, ResNet и т. Д.
  3. torchvision.transforms: Часто используемые преобразования изображений, такие как обрезка, поворот и т.д .;
  4. torchvision.utils: Некоторые другие полезные методы

3.5.1. Получение набора данных

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append('..')  # 为了导入上层的d2lzh_pytorch
import d2lzh_pytorch as d2l

Ниже мы torchvision.datasetзагружаем этот набор данных через torchvision . Данные будут автоматически загружены из Интернета при первом вызове. Мы используем параметры, trainчтобы указать, чтобы получить обучающий набор или набор данных тестирования (данные тестирования).

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

Если загрузка набора данных идет медленно, рекомендуется загрузить 4 через веб-сайт , а затем импортировать его в соответствии с загруженным каталогом. Как показано на рисунке ниже, я загрузил (4 файла, здесь упоминается только один, другие похожи) train-images-idx3-ubyte.gz в локальный каталог C: \ Users \ 1 / Datasets / FashionMNIST \ FashionMNIST \ raw \. Загрузите 4 прямо из URL-адреса (не нужно распаковывать) в каталог и выполните приведенный выше код.

Вставьте описание изображения сюда

# 上面的 mnist_train 和 mnist_test都是 torch.utils.data.Datasets的子类
# 所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本
# 训练集中和测试集中的每个类别的图像分别是6000和1000。因为有10个类别,所以训练集和测试集的样本数分别为60000和10000
print(type(mnist_train))   # <class 'totchvision.datasets.mnist.FashionMNIST'>
print(len(mnist_train), len(mnist_test))   # 60000 10000
# 通过下标访问任意样本
feature, label = mnist_train[0]
print(feature.shape, label)   # torch.Size([1, 28, 28]) tensor(5)

Вставьте описание изображения сюда
Переменная featureсоответствует изображению, высота и ширина которого составляют 28 пикселей. Поскольку мы использовали его transforms.ToTensor(), значение каждого пикселя представляет собой 32-битное число с плавающей запятой [0.0, 1.0]. Следует отметить, что размер элемента равен (C * H * W), а не (H * W * C). Первое измерение - это количество каналов, потому что количество каналов данных равно 1. Следующие два измерения - это высота и ширина изображения.

Всего в Fashion-MNIST 10 категорий: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9.

# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
#     text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    text_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    return [text_labels[int(i)] for i in labels]

# 定义一个可以在一行里面画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    _, figs = plt.subplots(1, len(images), figsize= (12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

X, y = [], []
for i in range(10):
    # 从数据集中取出10个
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

Вставьте описание изображения сюда

3.5.2. Чтение небольших партий

Мы обучим модель на обучающем наборе данных и оценим производительность модели на тестовом наборе. Как упоминалось ранее, это mnist_trainявляется torch.utils.data.Datasetподклассом, так что мы можем передать его в , torch.utils.data.DataLoaderчтобы создать экземпляр DataLoader , который читает небольшие партии образцов данных

batch_size = 256
if sys.platform.startswith('win'):
    num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:
    num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

start = time.time()
for X,y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))

Вставьте описание изображения сюда

рекомендация

отblog.csdn.net/piano9425/article/details/107150167
рекомендация