翻译: 3.5. 图像分类数据集Fashion-MNIST pytorch

用于图像分类的广泛使用的数据集之一是 MNIST 数据集[LeCun et al., 1998]。虽然它作为基准数据集运行良好,但即使按照今天的标准,即使是简单的模型也能达到 95% 以上的分类准确率,这使得它不适合区分强模型和弱模型。今天,MNIST 更多的是作为健全性检查而不是基准。为了提高赌注,我们将在接下来的部分中将讨论重点放在质量相似但相对复杂的 Fashion-MNIST 数据集 [Xiao et al., 2017]上,该数据集于 2017 年发布。

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

3.5.1 读取数据集

我们可以通过框架中的内置函数下载 Fashion-MNIST 数据集并将其读入内存。

# `ToTensor` converts the image data from PIL type to 32-bit floating point
# tensors. It divides all numbers by 255 so that all pixel values are between
# 0 and 1
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

Fashion-MNIST 由来自 10 个类别的图像组成,每个类别由训练数据集中的 6000 张图像和测试数据集中的 1000 张图像表示。测试数据集(或测试集)用于 评估模型性能,而不是用于训练。因此,训练集和测试集分别包含 60000 和 10000 张图像。

len(mnist_train), len(mnist_test)
(60000, 10000)

每个输入图像的高度和宽度都是 28 像素。请注意,该数据集由灰度图像组成,其通道数为 1。为简洁起见,在本书中,我们存储了具有高度的任何图像的高度h,宽度w,像素为hxw

mnist_train[0][0].shape
torch.Size([1, 28, 28])

Fashion-MNIST 中的图像与以下类别相关联:T 恤、裤子、套头衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包和踝靴。以下函数在数字标签索引和它们在文本中的名称之间进行转换。

def get_fashion_mnist_labels(labels):  #@save
    """Return text labels for the Fashion-MNIST dataset."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

我们现在可以创建一个函数来可视化这些示例。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Tensor Image
            ax.imshow(img.numpy())
        else:
            # PIL Image
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

以下是训练数据集中前几个示例的图像及其对应的标签(以文本形式)。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

在这里插入图片描述

3.5.2. 读取小批量

为了让我们在读取训练集和测试集时更轻松,我们使用内置的数据迭代器,而不是从头开始创建一个。回想一下,在每次迭代中,数据迭代器每次都会读取具有大小的小批量数据batch_size。我们还随机打乱训练数据迭代器的示例。

batch_size = 256

def get_dataloader_workers():  #@save
    """Use 4 processes to read the data."""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

让我们看看读取训练数据所需的时间。

timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{
      
      timer.stop():.2f} sec'
'2.46 sec'

3.5.3. 把所有东西放在一起

现在我们定义load_data_fashion_mnist获取和读取 Fashion-MNIST 数据集的函数。它返回训练集和验证集的数据迭代器。此外,它接受一个可选参数以将图像大小调整为另一种形状。

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """Download the Fashion-MNIST dataset and then load it into memory."""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

下面我们 load_data_fashion_mnist通过指定resize 参数来测试函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

我们现在已准备好在接下来的部分中使用 Fashion-MNIST 数据集。

3.5.4。概括

  • Fashion-MNIST 是一个服装分类数据集,由代表 10 个类别的图像组成。我们将在后续章节和章节中使用这个数据集来评估各种分类算法。

  • 我们用高度存储任何图像的高度h,宽度w ,像素hxw.

  • 数据迭代器是高效性能的关键组件。依靠利用高性能计算的良好实现的数据迭代器来避免减慢训练循环。

3.5.5。练习

  1. 减少batch_size(例如,减少到 1)会影响阅读性能吗?
    读取的总数是一样的,工作总量是一样的。batch_size的目的一个是为了并行,一个是为了减少一次读太多数据,对内存存储要求太高。

  2. 数据迭代器的性能很重要。您认为当前的实施速度是否足够快?探索各种改进方案。

  3. 查看框架的在线 API 文档。还有哪些其他数据集可用?
    https://pytorch.org/docs/stable/torchvision/datasets.html
    Datasets:

MNIST
Fashion-MNIST
KMNIST
EMNIST
QMNIST
FakeData
COCO:Captions,Detection
LSUN
ImageFolder
DatasetFolder
ImageNet
CIFAR
STL10
SVHN
PhotoTour
SBU
Flickr
VOC
Cityscapes
SBD
USPS
Kinetics-400
HMDB51
UCF101
CelebA

参考

https://d2l.ai/chapter_linear-networks/image-classification-dataset.html

猜你喜欢

转载自blog.csdn.net/zgpeace/article/details/123837420