[pytorch, learning]-3.5 image classification data set

3.5. Image classification data set

Before implementing softmax regression, we first introduce a multi-class image classification data set, Fashion-MNIST.

Starting with this chapter, the implementations use PyTorch.

In this section we use the torchvision package, which accompanies the PyTorch deep learning framework and is mainly used to build computer vision models. torchvision consists mainly of the following parts:

  1. torchvision.datasets: functions for loading data and interfaces to commonly used data sets;
  2. torchvision.models: commonly used model architectures (including pre-trained models), such as AlexNet, VGG, and ResNet;
  3. torchvision.transforms: commonly used image transformations, such as cropping and rotation;
  4. torchvision.utils: other useful utilities.
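In torchvision, several transforms are usually chained with transforms.Compose, which applies them one after another. The sketch below imitates that idea in plain Python so it runs without torchvision installed. compose, scale_to_unit and horizontal_flip are hypothetical helpers that act on a flat list of 8-bit pixel values. They are not torchvision functions.

```python
def compose(*transforms):
    """Hypothetical helper: return a function that applies `transforms` left to right,
    similar in spirit to torchvision.transforms.Compose (not its implementation)."""
    def apply(x):
        for t in transforms:
            x = t(x)
        return x
    return apply


# Toy transforms on a flat list of 8-bit pixel values (illustrative data only).
def scale_to_unit(pixels):
    # Map integer values in [0, 255] to floats in [0.0, 1.0].
    return [p / 255.0 for p in pixels]


def horizontal_flip(pixels):
    # Reverse the order of the pixels in a single row.
    return pixels[::-1]


pipeline = compose(horizontal_flip, scale_to_unit)
print(pipeline([0, 51, 255]))  # [1.0, 0.2, 0.0]
```

The order of the arguments matters. Each transform receives the output of the previous one, which is the same contract Compose follows.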

3.5.1. Get the data set

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append('..')  # so that d2lzh_pytorch in the parent directory can be imported
import d2lzh_pytorch as d2l

Below, we download this data set with torchvision.datasets. The data is downloaded automatically from the Internet the first time the code runs. The train parameter specifies whether to obtain the training set or the test set (testing data).

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

If the download is slow, you can download the four data files manually from the website and place them in the directory the code expects. As shown in the figure below, I downloaded train-images-idx3-ubyte.gz (one of the four files; the other three are handled the same way) to the local directory C:\Users\1/Datasets/FashionMNIST\FashionMNIST\raw\. Put all four files into that directory as downloaded (there is no need to decompress them), then run the code above.

[Figure: train-images-idx3-ubyte.gz placed in the local FashionMNIST\raw directory]
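Before you rerun the loading code, it can help to check which archives are still missing. The sketch below assumes the standard Fashion-MNIST file names and the root/FashionMNIST/raw layout that matches the directory shown above. missing_raw_files is a hypothetical helper, not part of torchvision.

```python
import os

# Standard file names of the four Fashion-MNIST archives.
RAW_FILES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def missing_raw_files(root="~/Datasets/FashionMNIST"):
    """Hypothetical helper: return the archives not yet present in
    <root>/FashionMNIST/raw (layout assumed from the path used in this post)."""
    raw_dir = os.path.join(os.path.expanduser(root), "FashionMNIST", "raw")
    return [name for name in RAW_FILES
            if not os.path.isfile(os.path.join(raw_dir, name))]


print(missing_raw_files())  # an empty list means all four files are in place
```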

# mnist_train and mnist_test above are instances of torch.utils.data.Dataset subclasses,
# so we can use len() to get the size of each data set and indexing to get a single sample.
# Each class has 6,000 training images and 1,000 test images. With 10 classes,
# the training and test sets contain 60,000 and 10,000 samples respectively.
print(type(mnist_train))   # <class 'torchvision.datasets.mnist.FashionMNIST'>
print(len(mnist_train), len(mnist_test))   # 60000 10000
# Access any sample by index
feature, label = mnist_train[0]
print(feature.shape, label)   # torch.Size([1, 28, 28]) tensor(5)
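len() and indexing work because a map-style Dataset implements __len__ and __getitem__. The sketch below is a hypothetical toy class (ToyImageDataset, not a torch class) that follows the same protocol in plain Python. Each of its "images" is a 28 x 28 grid of zeros, and each is paired with an integer label.

```python
class ToyImageDataset:
    """Hypothetical stand-in for a map-style dataset: it implements the same
    __len__/__getitem__ protocol that a torch.utils.data.Dataset subclass provides."""

    def __init__(self, num_classes=10, per_class=3, size=28):
        # Build (image, label) pairs, grouped by label: 0, 0, 0, 1, 1, 1, ...
        self.samples = [
            ([[0.0] * size for _ in range(size)], label)
            for label in range(num_classes)
            for _ in range(per_class)
        ]

    def __len__(self):
        # Called by len(ds).
        return len(self.samples)

    def __getitem__(self, index):
        # Called by ds[index]; returns one (image, label) pair.
        return self.samples[index]


ds = ToyImageDataset()
image, label = ds[4]
print(len(ds), len(image), len(image[0]), label)  # 30 28 28 1
```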

The variable feature corresponds to an image whose height and width are both 28 pixels. Because we used transforms.ToTensor(), each pixel value is a 32-bit floating-point number in the range [0.0, 1.0]. Note that the shape of feature is (C x H x W), not (H x W x C). The first dimension is the number of channels, which is 1 because the images are grayscale, and the last two dimensions are the height and width of the image.
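To see where the (C x H x W) shape and the [0.0, 1.0] range come from, here is a sketch of the conversion ToTensor performs on an 8-bit image. It uses NumPy in place of torch so it stays self-contained. to_tensor_like is a hypothetical name, and the random image is fake data.

```python
import numpy as np


def to_tensor_like(img_uint8):
    """Sketch of the ToTensor conversion for a uint8 image:
    H x W or H x W x C in [0, 255] -> float32 C x H x W in [0.0, 1.0]."""
    arr = np.asarray(img_uint8)
    if arr.ndim == 2:               # grayscale H x W: add a channel axis
        arr = arr[:, :, None]
    arr = arr.transpose(2, 0, 1)    # H x W x C -> C x H x W
    return arr.astype(np.float32) / 255.0


img = np.random.randint(0, 256, size=(28, 28), dtype=np.uint8)  # fake grayscale image
x = to_tensor_like(img)
print(x.shape, x.dtype)  # (1, 28, 28) float32
```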

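Fashion-MNIST has 10 categories, labeled 0, 1, 2, 3, 4, 5, 6, 7, 8, 9. The function below converts numeric labels into text labels. Here the text labels are simply the digits, and the original clothing names are kept in a comment.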

# This function is saved in the d2lzh package for later use
def get_fashion_mnist_labels(labels):
#     text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    text_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    return [text_labels[int(i)] for i in labels]
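If you want the clothing names instead of digits, the same lookup works with the commented-out list. The sketch below uses that list. get_fashion_mnist_text_labels and FASHION_MNIST_CLASSES are hypothetical names chosen so they do not clash with the function above.

```python
# Clothing names for labels 0-9, from the commented-out list in the function above.
FASHION_MNIST_CLASSES = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                         'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']


def get_fashion_mnist_text_labels(labels):
    # int() turns each label (e.g. a Python int) into a list index.
    return [FASHION_MNIST_CLASSES[int(i)] for i in labels]


print(get_fashion_mnist_text_labels([0, 9, 5]))  # ['t-shirt', 'ankle boot', 'sandal']
```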

# Define a function that plots several images and their labels in one row
def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

X, y = [], []
for i in range(10):
    # Take the first 10 samples from the data set
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

3.5.2. Reading small batches

We train the model on the training set and evaluate its performance on the test set. As mentioned earlier, mnist_train is an instance of a torch.utils.data.Dataset subclass, so we can pass it to torch.utils.data.DataLoader to create a DataLoader instance that reads mini-batches of samples. The num_workers argument sets how many worker processes read data in parallel. The code below uses 0 (load in the main process) on Windows and 4 elsewhere.

batch_size = 256
if sys.platform.startswith('win'):
    num_workers = 0  # 0 means no extra processes are used to speed up data loading
else:
    num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
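Conceptually, with these settings each epoch shuffles the sample indices (training set only) and then cuts them into consecutive batches of batch_size. The last, smaller batch is kept because drop_last defaults to False. The sketch below shows that index bookkeeping in plain Python. minibatch_indices is a hypothetical helper, not DataLoader's actual sampler code.

```python
import random


def minibatch_indices(num_samples, batch_size, shuffle=True, seed=None):
    """Conceptual sketch of DataLoader's default batching: optionally shuffle
    the indices once per epoch, then split them into consecutive batches.
    The final batch may be smaller, as with drop_last=False."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]


train_batches = minibatch_indices(60000, 256, shuffle=True, seed=0)
test_batches = minibatch_indices(10000, 256, shuffle=False)
print(len(train_batches), len(train_batches[-1]))  # 235 96
print(len(test_batches), len(test_batches[-1]))    # 40 16
```

So one pass over train_iter yields 235 batches: 234 full batches of 256 samples and a final batch of 96 (234 x 256 + 96 = 60000).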

# Time one full pass over the training set
start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))

Origin blog.csdn.net/piano9425/article/details/107150167