3.5. The Image Classification Dataset Fashion-MNIST (PyTorch)

One of the most widely used datasets for image classification is the MNIST dataset [LeCun et al., 1998]. While it had a good run as a benchmark dataset, even simple models by today's standards achieve classification accuracy over 95%, making it unsuitable for distinguishing between stronger and weaker models. Today, MNIST serves more as a sanity check than as a benchmark. To up the ante, we will focus our discussion in the following sections on the qualitatively similar, but comparatively complex Fashion-MNIST dataset [Xiao et al., 2017], which was released in 2017.

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

3.5.1. Reading the dataset

We can download the Fashion-MNIST dataset and read it into memory via built-in functions in the framework.

# `ToTensor` converts the image data from PIL type to 32-bit floating point
# tensors. It divides all numbers by 255 so that all pixel values are between
# 0 and 1
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

Fashion-MNIST consists of images from 10 categories, each represented by 6000 images in the training dataset and 1000 images in the testing dataset. The test dataset (or test set) is used to evaluate model performance, not for training. Therefore, the training and test sets contain 60,000 and 10,000 images, respectively.

len(mnist_train), len(mnist_test)
(60000, 10000)

The height and width of each input image are 28 pixels. Note that the dataset consists of grayscale images, whose number of channels is 1. For brevity, throughout this book we denote the shape of any image with height h pixels and width w pixels as h x w, or (h, w).

mnist_train[0][0].shape
torch.Size([1, 28, 28])
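
As a quick sanity check (an added illustration, not part of the original text), we can confirm that `ToTensor` really scales the pixel values into [0, 1] and that each example is an (image, label) pair whose label is a plain integer:

image, label = mnist_train[0]
# `ToTensor` yields a float32 tensor; every pixel value should lie in [0, 1],
# and the label is an integer class index between 0 and 9
image.dtype, float(image.min()), float(image.max()), label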

The images in Fashion-MNIST are associated with the following categories: t-shirt, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag, and ankle boot. The following function converts between numeric label indices and their names in text.

def get_fashion_mnist_labels(labels):  #@save
    """Return text labels for the Fashion-MNIST dataset."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
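
For instance (a usage example added here), passing a few numeric label indices returns the corresponding text names:

get_fashion_mnist_labels([0, 2, 9])
['t-shirt', 'pullover', 'ankle boot']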

We can now create a function to visualize these examples.

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Tensor Image
            ax.imshow(img.numpy())
        else:
            # PIL Image
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

Below are the images and their corresponding labels (in text form) for the first few examples in the training dataset.

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

[Figure: the first 18 Fashion-MNIST training images with their text labels]

3.5.2. Reading minibatches

To make reading the training and test sets easier, we use the built-in data iterator rather than creating one from scratch. Recall that at each iteration, the data iterator reads a minibatch of data of size batch_size. We also randomly shuffle the examples for the training data iterator.

batch_size = 256

def get_dataloader_workers():  #@save
    """Use 4 processes to read the data."""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())
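
As a quick check (added here; the original only times the iteration), fetching a single minibatch confirms that it contains batch_size grayscale images of 28 x 28 pixels and a vector of batch_size labels:

X, y = next(iter(train_iter))
X.shape, y.shape
(torch.Size([256, 1, 28, 28]), torch.Size([256]))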

Let's see how long it takes to read the training data.

timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'
'2.46 sec'

3.5.3. Putting everything together

Now we define the load_data_fashion_mnist function that obtains and reads the Fashion-MNIST dataset. It returns the data iterators for both the training set and the test set. In addition, it accepts an optional argument resize to resize images to another shape.

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """Download the Fashion-MNIST dataset and then load it into memory."""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

Below we test the image resizing capabilities of the function load_data_fashion_mnist by specifying the resize parameter.

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

We are now ready to use the Fashion-MNIST dataset in the next sections.

3.5.4. Summary

  • Fashion-MNIST is an apparel classification dataset consisting of images from 10 categories. We will use this dataset in subsequent sections and chapters to evaluate various classification algorithms.

  • We store the shape of any image with height h pixels and width w pixels as h x w, or (h, w).

  • Data iterators are a key component of efficient performance. Rely on well-implemented data iterators that take advantage of high-performance computing to avoid slowing down the training loop.

3.5.5. Exercises

  1. Does reducing batch_size (e.g., to 1) affect read performance?
    The total amount of data read is the same, but a smaller batch size means more iterations and more per-batch overhead, so reading typically becomes slower. Batching serves two purposes: it enables parallelism within a batch, and it limits how much data must be held in memory at one time. A rough timing sketch is given below.
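
    A rough way to verify this empirically (an added sketch, not part of the original answer; actual numbers depend on hardware) is to time one full pass over the training set for a few batch sizes:

# Time one epoch of pure data reading for several batch sizes. Smaller batches
# typically take longer because of per-batch overhead in Python and collation.
for bs in [1, 64, 256]:
    loader = data.DataLoader(mnist_train, bs, shuffle=True,
                             num_workers=get_dataloader_workers())
    timer = d2l.Timer()
    for X, y in loader:
        continue
    print(f'batch_size={bs}: {timer.stop():.2f} sec')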

  2. The performance of data iterators is important. Do you think the current implementation is fast enough? Explore various improvement options.
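
    One option worth exploring (an added sketch using standard DataLoader arguments such as num_workers and pin_memory; results depend on hardware) is to vary the number of worker processes:

# Compare reading times for different numbers of worker processes; `pin_memory`
# can additionally speed up later transfers of the batches to a GPU.
for workers in [0, 2, 4, 8]:
    loader = data.DataLoader(mnist_train, batch_size=256, shuffle=True,
                             num_workers=workers, pin_memory=True)
    timer = d2l.Timer()
    for X, y in loader:
        continue
    print(f'num_workers={workers}: {timer.stop():.2f} sec')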

  3. Check out the framework's online API documentation. What other datasets are available?
    https://pytorch.org/docs/stable/torchvision/datasets.html
    Datasets:

MNIST
Fashion-MNIST
KMNIST
EMNIST
QMNIST
FakeData
COCO: Captions, Detection
LSUN
ImageFolder
DatasetFolder
ImageNet
CIFAR
STL10
SVHN
PhotoTour
SBU
Flickr
VOC
Cityscapes
SBD
USPS
Kinetics-400
HMDB51
UCF101
CelebA
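
All of these follow the same torchvision.datasets interface, so switching datasets usually only requires changing the class name. For example (an added illustration), CIFAR-10 can be loaded just like Fashion-MNIST above; its examples are 3-channel 32 x 32 colour images:

# CIFAR-10 follows the same loading pattern as Fashion-MNIST above
cifar_train = torchvision.datasets.CIFAR10(
    root="../data", train=True, transform=transforms.ToTensor(), download=True)
cifar_train[0][0].shape
torch.Size([3, 32, 32])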

Reference

https://d2l.ai/chapter_linear-networks/image-classification-dataset.html

Origin: blog.csdn.net/zgpeace/article/details/123837420