MXNet动手学深度学习:Fashion-MNIST数据集读取

获取数据集

数据集简介

本节中将使用数据集Fashion-MNIST,Fashion-MNIST 中⼀共包括了 10 个类别,分别为:t-shirt(T 恤)、trouser(裤⼦)、pullover(套衫)、dress(连⾐裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和 ankle boot(短靴)。在该数据集中含有6000个训练集样本、1000个测试集样本。
为什么不用手写数据集做实验了呢?
原因为:
1、MNIST is too easy. Convolutional nets can achieve 99.7% on MNIST,太简单
2、MNIST is overused,过度使用
3、MNIST can not represent modern CV tasks, 过时
在这里插入图片描述

代码讲解

导入包

import gluonbook as gb
from mxnet.gluon import data as gdata
import sys
import time

获取数据集与测试集

#获取数据集与测试集
#MXNet的gdata.vision提供了FashionMNIST的数据集
mnist_train = gdata.vision.FashionMNIST(train = True)
mnist_test = gdata.vision.FashionMNIST(train = False)

将数值标签转化为文本标签

因为数据集含有10个类别,在数据中的标签为0-9,所以在图像显示时,我们想更加方便的知道图像对应的的类别名字,而不是冰冷的数字,那么就可以采用一下代码,将数字标签转化为与其对应的文本标签。

#将数值标签转化为文本标签
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

显示图像的函数

在使用代码的时候,难免想观察一下图像效果,通过该函数,可以显示出图像与其对应的标签。

def show_fashion_mnist(images, labels):
    gb.use_svg_display()
    _, figs = gb.plt.subplots(1, len(images), figsize = (12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.reshape((28, 28)).asnumpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)

测试 show_fashion_mnist函数

由于本人使用的为服务器,没有界面显示。这里就不贴显示的效果了,对于没有界面显示的小伙伴,可以采用将图片进行保存,然后传到本地机器,进行查看。服务器与本地机器之间传文件的指令为:rz,sz。不懂的小伙伴可以查看这篇文章:Linux与windows文件传输

#显示训练集中前10张图片及其对应的文字标签
X, y = mnist_train[0:9]
show_fashion_mnist(X, get_fashion_mnist_labels(y))

读取小批量

ToTensor 类将图像数据从 uint8 格式变换成 32 位浮点数格式,并除以 255 使得所有像素的数值均在 0 到 1 之间。ToTensor 类还将图像通道从最后⼀维移到最前⼀维来⽅便之后介绍的卷积神经⽹络计算。

transformer = gdata.vision.transforms.ToTensor()
#num_workers参数用来设置读取数据的进程数,目的是:用多进程加速数据的读取。
if sys.platform.startswith('win'):
    num_workers = 0
else:
    num_workers = 4
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),
                            batch_size, shuffle = True,
                            num_workers = num_workers) #num_workers设置进程数 
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),
                            batch_size, shuffle = False,
                            num_workers = num_workers)

train_iter和test_iter这两个变量被保存到了gluonbook.load_data_fashion_mnist函数中了,在后面的实现中可用使用该函数,会返回这两个变量。。。
gluonbook.load_data_fashion_mnist这个函数具体是怎样实现的?目前还不知道,书中写到后续会有介绍
至此,数据的读取已经完成,包括训练集与测试集的读取

代码

#图像分类数据集
import gluonbook as gb
from mxnet.gluon import data as gdata
import sys
import time

#获取数据集与测试集
#MXNet的gdata.vision提供了FashionMNIST的数据集
mnist_train = gdata.vision.FashionMNIST(train = True)
mnist_test = gdata.vision.FashionMNIST(train = False)

print(len(mnist_train)) #训练集样本数
print(len(mnist_test)) #测试集样本数

feature, label = mnist_train[0]
print(feature.shape) #图像是28*28的灰度图像
print(label.dtype) # label的类型

#将数值标签转化为文本标签
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
#显示图像...由于我使用的是服务器,没有界面显示。所以运行显示图片的程序的时候会出现问题
def show_fashion_mnist(images, labels):
    gb.use_svg_display()
    _, figs = gb.plt.subplots(1, len(images), figsize = (12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.reshape((28, 28)).asnumpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)

#显示训练集中前10张图片及其对应的文字标签
X, y = mnist_train[0:9]
show_fashion_mnist(X, get_fashion_mnist_labels(y))

batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
#num_workers参数用来设置读取数据的进程数,目的是:用多进程加速数据的读取。
if sys.platform.startswith('win'):
    num_workers = 0
else:
    num_workers = 4

train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),
                            batch_size, shuffle = True,
                            num_workers = num_workers) #num_workers设置进程数 

test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),
                            batch_size, shuffle = False,
                            num_workers = num_workers)
start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start)) #打印显示读取数据所需时间                         

总结

通过本章的学习,了解了数据集的读取,但是这个数据集的读取,读取的为gluon里已经含有的数据集,如果是没有包含的数据集应该是不能这样读取的。学习总是有个过程的,先知道这一种也是一种不错的体验,希望以后MXNet能够越来越好,能够直接调用更多的数据集。
后续,将会对softmax回归进行学习。

猜你喜欢

转载自blog.csdn.net/xiaobiyin9140/article/details/84749253