深度学习4:使用MNIST数据集(tensorflow)

本文将介绍MNIST数据集的数据格式和使用方法,使用到的是tensorflow中封装的类,包含代码。

这里写图片描述

MNIST数据集来源于这里, 如果希望下载原始格式的数据集,可以从这里下载。而本文中讲解的是已经使用python代码封装好的MNIST数据集。封装的代码作为tensorflow的一部分,内部使用了numpy。所以,在使用这段封装的代码的时候,返回值将numpy的对象。因为numpy是常用和非常实用的工具包,所以,使用numpy具有很多好处。

首先导入包:

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

获取MNIST数据集主需要执行下面一行代码:

mnist = input_data.read_data_sets('dir_with_mnist_data_files', one_hot=False)

mnist数据集把数据分成了三部分,分别是train、test和validation,作用分别是训练、测试和认证。有时候test不是必须的。它们的数据集的大小可以通过下面这段代码得到:

    print("train example number: ", mnist.train.num_examples)
    print("train image shape: ", mnist.train.images.shape)
    # print(mnist.train.images[0:1])
    print("train label shape: ", mnist.train.labels.shape)
    print("test image shape: ", mnist.test.images.shape)
    print("test image label shape: ", mnist.test.labels.shape)
    print("validataion image shape: ", mnist.validation.images.shape)
    print("validataion label shape: ", mnist.validation.labels.shape)

    batch_xs, batch_ys = mnist.train.next_batch(batch_size=100)
    print("batch_xs shape is , batch_ys shape is ", batch_xs.shape, batch_ys.shape)

    print("max value in image: ", np.max(mnist.test.images))
    print("min value in image", np.min(mnist.test.images))
    print("max value in label", np.max(mnist.test.labels))
    print("min value in label", np.min(mnist.test.labels))

    one_hot_mnist = input_data.read_data_sets('dir_with_mnist_data_files', one_hot=True)
    print("train image shape of one_hot mnist shape: ", one_hot_mnist.train.images.shape)
    print("train label shape of one_hot mnist shape: ", one_hot_mnist.train.labels.shape)

上面这段代码的执行结果为:

train example number:  55000
train image shape:  (55000, 784)
train label shape:  (55000,)
test image shape:  (10000, 784)
test image label shape:  (10000,)
validataion image shape:  (5000, 784)
validataion label shape:  (5000,)
batch_xs shape is , batch_ys shape is  (100, 784) (100,)
max value in image:  1.0
min value in image 0.0
max value in label 9
min value in label 0
train image shape of one_hot mnist shape:  (55000, 784)
train label shape of one_hot mnist shape:  (55000, 10)

总结如下:
上面的代码对象比如mnist.train.num_examples是numpy对象。

有55000个train 用例,每一个用例的x是一个长度为28 x 28 = 784的数组,其中的最大值为1,最小值为0;每一个用例的y或者label是一个数,其中最大值为9, 最小值为0。

有10000个test用例。

有5000个validation用例。

如果是one hot格式,也就是上面的代码中的one_hot参数为True,那么每一个用例中的label不再是一个数,而是一个长度为10的数组,该数组中有9个0,只有一个1(注意是一个用例)。比如如果是手写字3, 那么对于one hot格式,就变成[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]。one hot编码格式在多分类模型中使用到。其他的格式不变。

结束! 谢谢

猜你喜欢

转载自blog.csdn.net/liangyihuai/article/details/78968730