mnist数据集读取并保存为Numpy数组

背景信息

MNIST数据集简介

MNIST数据集是从 NIST 的Special Database 3(SD-3)和Special Database 1(SD-1)构建而来。由于SD-3是由美国人口调查局的员工进行标注,SD-1是由美国高中生进行标注,因此SD-3比SD-1更干净也更容易识别。Yann LeCun等人从SD-1和SD-3中各取一半作为MNIST的训练集(60000条数据)和测试集(10000条数据),其中训练集来自250位不同的标注员,此外还保证了训练集和测试集的标注员是不完全相同的。

本文目的

本文实现MNIST数据集和标签的读取,并转化为Numpy的数组进行输出。

前提条件

以完成MNIST数据集的下载,如下所示:

root@5e3ac72a80f4:~/.cache/paddle/dataset/mnist# ll
total 11344
drwxr-xr-x  2 root root    4096 Mar 12 03:22 ./
drwxr-xr-x 13 root root    4096 Apr  1 07:01 ../
-rw-r--r--  1 root root 1648877 Mar 12 03:22 t10k-images-idx3-ubyte.gz
-rw-r--r--  1 root root    4542 Mar 12 03:22 t10k-labels-idx1-ubyte.gz
-rw-r--r--  1 root root 9912422 Mar 12 03:22 train-images-idx3-ubyte.gz
-rw-r--r--  1 root root   28881 Mar 12 03:22 train-labels-idx1-ubyte.gz

详细代码

#导入所需包
import subprocess
import numpy
import platform
#定义变量
image_filename='/root/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz'
label_filename='/root/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz'
buffer_size=100
# 定义函数读取image,并保存为数组
def get_images(image_filename, buffer_size):
    m = subprocess.Popen(['zcat', image_filename], stdout=subprocess.PIPE)
    m.stdout.read(16)  # skip some magic bytes
    images=numpy.fromfile(m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape((buffer_size, 28 * 28)).astype('float32')
    images = images / 255.0 * 2.0 - 1.0
    m.terminate()
    return images
# 定义函数读取labels,并保存为数组
def get_labels(label_filename, buffer_size):
    l = subprocess.Popen(['zcat', label_filename], stdout=subprocess.PIPE)
    l.stdout.read(8)  # skip some magic bytes
    labels = numpy.fromfile(l.stdout, 'ubyte', count=buffer_size).astype("int")
    #print labels.shape
    l.terminate()
    return labels
# 创建Paddle中使用的def reader_create(image_filename, label_filename, buffer_size)
def mnist_reader(image_filename, label_filename, buffer_size):
    def reader():
        images=get_images(image_filename, buffer_size)
        labels=get_labels(label_filename, buffer_size)
        for i in xrange(buffer_size):
            yield images[i,:], int(labels[i])
    return reader

查看结果:




猜你喜欢

转载自blog.csdn.net/wiborgite/article/details/79785167