Traducción: 3.5 Conjunto de datos de clasificación de imágenes Fashion-MNIST pytorch

Uno de los conjuntos de datos más utilizados para la clasificación de imágenes es el conjunto de datos MNIST [LeCun et al., 1998]. Si bien funciona bien como conjunto de datos de referencia, incluso los modelos simples logran una precisión de clasificación de más del 95 % según los estándares actuales, lo que lo hace inadecuado para distinguir entre modelos fuertes y débiles. Hoy, MNIST es más un control de cordura que un punto de referencia. Para aumentar la apuesta, enfocamos nuestra discusión en las siguientes secciones en el conjunto de datos Fashion-MNIST de calidad similar pero relativamente complejo [Xiao et al., 2017], que se publicó en 2017.

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

3.5.1 Leer el conjunto de datos

Podemos descargar el conjunto de datos Fashion-MNIST y leerlo en la memoria a través de funciones integradas en el marco.

# `ToTensor` converts the image data from PIL type to 32-bit floating point
# tensors. It divides all numbers by 255 so that all pixel values are between
# 0 and 1
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

Fashion-MNIST consta de imágenes de 10 categorías, cada una representada por 6000 imágenes en el conjunto de datos de entrenamiento y 1000 imágenes en el conjunto de datos de prueba. El conjunto de datos de prueba (o conjunto de prueba) se utiliza para evaluar el rendimiento del modelo, no para el entrenamiento. Por lo tanto, los conjuntos de entrenamiento y prueba contienen 60 000 y 10 000 imágenes, respectivamente.

len(mnist_train), len(mnist_test)
(60000, 10000)

La altura y el ancho de cada imagen de entrada son 28 píxeles. Tenga en cuenta que este conjunto de datos consta de imágenes en escala de grises con un recuento de canales de 1. Para abreviar, en este libro almacenamos el alto h, el ancho y los wpíxeles de cualquier imagen que tenga un alto comohxw

mnist_train[0][0].shape
torch.Size([1, 28, 28])

Las imágenes de Fashion-MNIST están asociadas a las siguientes categorías: camisetas, pantalones, jerséis, vestidos, abrigos, sandalias, camisas, zapatillas, bolsos y botines. La siguiente función convierte entre índices de etiquetas numéricas y sus nombres en texto.

def get_fashion_mnist_labels(labels):  #@save
    """Return text labels for the Fashion-MNIST dataset."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

Ahora podemos crear una función para visualizar estos ejemplos.

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Tensor Image
            ax.imshow(img.numpy())
        else:
            # PIL Image
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

A continuación se encuentran las imágenes y sus etiquetas correspondientes (en forma de texto) para los primeros ejemplos en el conjunto de datos de entrenamiento.

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

inserte la descripción de la imagen aquí

3.5.2 Lectura de minilotes

Para facilitarnos la lectura de los conjuntos de entrenamiento y prueba, usamos los iteradores de datos integrados en lugar de crear uno desde cero. Recuerde que en cada iteración, el iterador de datos lee mini lotes de datos con tamaño batch_size cada vez. También mezclamos aleatoriamente los ejemplos del iterador de datos de entrenamiento.

batch_size = 256

def get_dataloader_workers():  #@save
    """Use 4 processes to read the data."""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

Veamos cuánto se tarda en leer los datos de entrenamiento.

timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{
      
      timer.stop():.2f} sec'
'2.46 sec'

3.5.3 Poniendo todo junto

Ahora definimos la función load_data_fashion_mnist para obtener y leer el conjunto de datos Fashion-MNIST. Devuelve iteradores de datos para conjuntos de entrenamiento y validación. Además, acepta un parámetro opcional para cambiar el tamaño de la imagen a otra forma.

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """Download the Fashion-MNIST dataset and then load it into memory."""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

A continuación, probamos las capacidades de cambio de tamaño de imagen de la función load_data_fashion_mnist especificando el parámetro de cambio de tamaño.

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

Ahora estamos listos para usar el conjunto de datos Fashion-MNIST en las siguientes secciones.

3.5.4. generalizar

  • Fashion-MNIST es un conjunto de datos de clasificación de ropa que consta de imágenes que representan 10 categorías. Usaremos este conjunto de datos en capítulos y capítulos posteriores para evaluar varios algoritmos de clasificación.

  • Usamos height para almacenar la altura h, el ancho y los wpíxeles de cualquier imagen hxw.

  • Los iteradores de datos son un componente clave del rendimiento eficiente. Confíe en iteradores de datos bien implementados que aprovechan la computación de alto rendimiento para evitar ralentizar el ciclo de entrenamiento.

3.5.5. práctica

  1. ¿Reducir el tamaño del lote (por ejemplo, a 1) afecta el rendimiento de lectura?
    El número total de lecturas es el mismo y el número total de trabajos es el mismo. Uno de los propósitos de batch_size es el paralelismo y el otro es reducir la lectura de demasiados datos a la vez, lo que requiere demasiado almacenamiento de memoria.

  2. El rendimiento de los iteradores de datos es importante. ¿Crees que la implementación actual es lo suficientemente rápida? Explore varias opciones de mejora.

  3. Consulte la documentación de la API en línea del marco. ¿Qué otros conjuntos de datos están disponibles?
    https://pytorch.org/docs/stable/torchvision/datasets.html
    Conjuntos de datos:

MNIST
Fashion-MNIST
KMNIST
EMNIST
QMNIST
FakeData
COCO:Captions,Detection
LSUN
ImageFolder
DatasetFolder
ImageNet
CIFAR
STL10
SVHN
PhotoTour
SBU
Flickr
VOC
Cityscapes
SBD
USPS
Kinetics-400
HMDB51
UCF101
CelebA

Referirse a

https://d2l.ai/chapter_linear-networks/image-classification-dataset.html

Supongo que te gusta

Origin blog.csdn.net/zgpeace/article/details/123837420
Recomendado
Clasificación