TensorFlow官方教程学习笔记(四)——MNIST数据集的读取

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/wspba/article/details/54311566

在TensorFlow的源码中,MNIST数据集的读取操作在contrib\learn\python\learn\datasets\data\mnist.py中。


主要看第189行的read_data_sets函数:

def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000):

train_dir为数据集在文件夹的位置,在这里为tensorflow\examples\tutorials\mnist\MNIST_data;

在官方教程中提到fake_data标记是用于单元测试的,读者可以不必理会;

one_hot为one_hot编码,即独热码,作用是将状态值编码成状态向量,例如,数字状态共有0~9这10种,对于数字7,将它进行one_hot编码后为[0 0 0 0 0 0 0 1 0 0],这样使得状态对于计算机来说更加明确,对于矩阵操作也更加高效。

dtype的作用是将图像像素点的灰度值从[0, 255]转变为[0.0, 1.0]。

reshape的作用是将图像的形状从[num examples, rows, columns, depth]转变为[num examples, rows*columns] (对于二维图片,depth为1)。

validation_size即为从训练集中抽取这么多来作为验证集。


变量定义好之后,接下来提取数据集。

先是图片文件:

with open(local_file, 'rb') as f:
    train_images = extract_images(f)

看extract_images函数,从第52行开始:

with gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                       (magic, f.name))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data

如果这么看代码可能很难理解,但是如果清楚MNIST数据集文件的结构之后就好理解得多,对于MNIST的images文件:

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
offset type value description
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
0018 unsigned byte ?? pixel
......      
xxxx unsigned byte ?? pixel

代码中_read32()的定义在第33行,作用是从文件流中动态读取4位数据并转换为uint32的数据。

image文件的前四位为魔术码(magic number),只有检测到这4位数据的值和2051相等时,才代表这是正确的image文件,才会继续往下读取。接下来继续读取之后的4位,代表着image文件中,所包含的图片的数量(num_images)。再接着读4位,为每一幅图片的行数(rows),再后4位,为每一幅图片的列数(cols)。最后再读接下来的rows * cols * num_images位,即为所有图片的像素值。最后再将读取到的所有像素值装换为[index, rows, cols, depth]的4D矩阵。这样就将全部的image数据读取了出来。


同理,对于MNIST的labels文件:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
offset type value description
0000 32 bit integer 0x00000801(2049) magic number
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
......      
xxxx unsigned byte ?? label

再看代码,从第90行开始:

with gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                       (magic, f.name))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels, num_classes)
    return labels

同样的也是依次读取文件的魔术码以及标签总数,最后把所有图片的标签读取出来,成一个长度为num_items的1D的向量。不过代码中还有一个one_hot的部分,dense_to_one_hot的代码为:

  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot

正如文章开头提到one_hot的作用,这里将1D向量中的每一个值,编码成一个长度为num_classes的向量,向量中对应于该值的位置为1,其余为0,所以one_hot将长度为num_labels的向量编码为一个[num_labels, num_classes]的2D矩阵。

以上就是如何将MNIST数据文件中的images和labels分别提取出来的过程,与TensorFlow和deeplearning无关,但是我觉得对于MNIST数据集的了解,以及后面的一些才做还是很有帮助的。

猜你喜欢

转载自blog.csdn.net/wspba/article/details/54311566