在TensorFlow的源码中,MNIST数据集的读取操作在contrib\learn\python\learn\datasets\data\mnist.py中。
主要看第189行的read_data_sets函数:
def read_data_sets(train_dir,
fake_data=False,
one_hot=False,
dtype=dtypes.float32,
reshape=True,
validation_size=5000):
train_dir为数据集在文件夹的位置,在这里为tensorflow\examples\tutorials\mnist\MNIST_data;
在官方教程中提到fake_data标记是用于单元测试的,读者可以不必理会;
one_hot为one_hot编码,即独热码,作用是将状态值编码成状态向量,例如,数字状态共有0~9这10种,对于数字7,将它进行one_hot编码后为[0 0 0 0 0 0 0 1 0 0],这样使得状态对于计算机来说更加明确,对于矩阵操作也更加高效。
dtype的作用是将图像像素点的灰度值从[0, 255]转变为[0.0, 1.0]。
reshape的作用是将图像的形状从[num examples, rows, columns, depth]转变为[num examples, rows*columns] (对于二维图片,depth为1)。
validation_size即为从训练集中抽取这么多来作为验证集。
变量定义好之后,接下来提取数据集。
先是图片文件:
with open(local_file, 'rb') as f:
train_images = extract_images(f)
看extract_images函数,从第52行开始:
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError('Invalid magic number %d in MNIST image file: %s' %
(magic, f.name))
num_images = _read32(bytestream)
rows = _read32(bytestream)
cols = _read32(bytestream)
buf = bytestream.read(rows * cols * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
data = data.reshape(num_images, rows, cols, 1)
return data
如果这么看代码可能很难理解,但是如果清楚MNIST数据集文件的结构之后就好理解得多,对于MNIST的images文件:
offset | type | value | description |
0000 | 32 bit integer | 0x00000803(2051) | magic number |
0004 | 32 bit integer | 60000 | number of images |
0008 | 32 bit integer | 28 | number of rows |
0012 | 32 bit integer | 28 | number of columns |
0016 | unsigned byte | ?? | pixel |
0017 | unsigned byte | ?? | pixel |
0018 | unsigned byte | ?? | pixel |
...... | |||
xxxx | unsigned byte | ?? | pixel |
代码中_read32()的定义在第33行,作用是从文件流中动态读取4位数据并转换为uint32的数据。
image文件的前四位为魔术码(magic number),只有检测到这4位数据的值和2051相等时,才代表这是正确的image文件,才会继续往下读取。接下来继续读取之后的4位,代表着image文件中,所包含的图片的数量(num_images)。再接着读4位,为每一幅图片的行数(rows),再后4位,为每一幅图片的列数(cols)。最后再读接下来的rows * cols * num_images位,即为所有图片的像素值。最后再将读取到的所有像素值装换为[index, rows, cols, depth]的4D矩阵。这样就将全部的image数据读取了出来。
同理,对于MNIST的labels文件:
offset | type | value | description |
0000 | 32 bit integer | 0x00000801(2049) | magic number |
0004 | 32 bit integer | 60000 | number of items |
0008 | unsigned byte | ?? | label |
0009 | unsigned byte | ?? | label |
...... | |||
xxxx | unsigned byte | ?? | label |
再看代码,从第90行开始:
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read32(bytestream)
if magic != 2049:
raise ValueError('Invalid magic number %d in MNIST label file: %s' %
(magic, f.name))
num_items = _read32(bytestream)
buf = bytestream.read(num_items)
labels = numpy.frombuffer(buf, dtype=numpy.uint8)
if one_hot:
return dense_to_one_hot(labels, num_classes)
return labels
同样的也是依次读取文件的魔术码以及标签总数,最后把所有图片的标签读取出来,成一个长度为num_items的1D的向量。不过代码中还有一个one_hot的部分,dense_to_one_hot的代码为:
num_labels = labels_dense.shape[0]
index_offset = numpy.arange(num_labels) * num_classes
labels_one_hot = numpy.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot
正如文章开头提到one_hot的作用,这里将1D向量中的每一个值,编码成一个长度为num_classes的向量,向量中对应于该值的位置为1,其余为0,所以one_hot将长度为num_labels的向量编码为一个[num_labels, num_classes]的2D矩阵。
以上就是如何将MNIST数据文件中的images和labels分别提取出来的过程,与TensorFlow和deeplearning无关,但是我觉得对于MNIST数据集的了解,以及后面的一些才做还是很有帮助的。