使用python处理MNIST数据集

一. MNIST数据集

1.1 什么是MNIST数据集

MNIST数据集是入门机器学习/识别模式的最经典数据集之一。最早于1998年Yan Lecun在论文:[Gradient-based learning applied to document recognition]中提出。该数据集包含了0-9共10类手写数字图片,每张图片都做了尺寸归一化,都是28x28大小的灰度图。每张图片中的像素大小在0-255之间,其中0是黑色,255是白色。如下图所示:

在这里插入图片描述

MNIST共包含70000张手写数字图片,其中有60000张用作训练集,10000张用作测试集。元数据集可以在MNIST官网下载。下载之后得到4个压缩文件:

train-images-idx3-ubyte.gz #60000张训练集图片
train-labels-idx1-ubyte.gz #60000张训练集图片对应的标签
t10k-images-idx3-ubyte.gz  #10000张测试集图片
t10k-labels-idx1-ubyte.gz  #10000张测试集图片对应的标签

将上面的4个压缩文件分别解压,得到:

train-images-idx3-ubyte #60000张训练集图片的idx3-ubyte格式文件
train-labels-idx1-ubyte #60000张训练集图片对应的标签的idx3-ubyte格式文件
t10k-images-idx3-ubyte  #10000张测试集图片的idx3-ubyte格式文件
t10k-labels-idx1-ubyte  #10000张测试集图片对应的标签的idx3-ubyte格式文件

1.2MNIST数据集文件格式

解压得到的4个文件都是二进制格式的文件,为了获取其中的信息,需要先了解MNIST二进制文件的存储格式。格式描述如下:
在这里插入图片描述

  • 第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051;
  • 第5-8个byte存的是number of images,即图像数量60000;
  • 第9-12个byte存的是每张图片行数/高度,即28;
  • 第13-16个byte存的是每张图片的列数/宽度,即28。
  • 从第17个byte开始,每个byte存储一张图片中的一个像素点的值。

1.3使用python访问MNIST数据集文件内容

知道了MNIST二进制文件的存储方式,下面介绍如何使用python访问文件内容。同样以训练集标签文件train-labels-idx1-ubyte和训练集图像文件train-images-idx3-ubyte为例:

import numpy as np
from PIL import Image

MNIST_labels_path = 'G:\\mnist_dataset\\train-labels-idx1-ubyte\\train-labels.idx1-ubyte'  # 下载的MNIST数据集文件地址
MNIST_images_path = 'G:\\mnist_dataset\\train-images-idx3-ubyte\\train-images.idx3-ubyte'  # 下载的MNIST数据集文件地址

with open(MNIST_labels_path, 'rb') as f:
    file_labels = f.read()  # 读入标签二进制文件
with open(MNIST_images_path, 'rb') as f:
    file_images = f.read()  # 读入照片二进制文件

magic_number_labels = int.from_bytes(file_labels[0:4], 'big')  # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_items = int.from_bytes(file_labels[4:8], 'big')  # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
print('labels: magic number = ', magic_number_labels)
print('labels: number of items = ', number_items)

magic_number = int.from_bytes(file_images[0:4], 'big')  # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_images = int.from_bytes(file_images[4:8], 'big')  # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
number_rows = int.from_bytes(file_images[8:12], 'big')  # 读取二进制文件的第9-12个byte( 1byte = 8bit ),即number of rows,并转换成10进制
number_columns = int.from_bytes(file_images[12:16], 'big')  # 读取二进制文件的第13-16个byte( 1byte = 8bit ),即number of columns,并转换成10进制
print('images: magic number = ', magic_number)
print('images: number of images = ', number_images)
print('images: number of rows = ', number_rows)
print('images: number of columns = ', number_columns)

使用with open() as 函数读取文件,并使用int.from_bytes()方法将文件的magic number, number of items, number of images, number of rows, number of columns,等数据读入,将字节数据转换成整数数据,从而查看图像数量、图像高度和图像宽度信息。
运行结果:

在这里插入图片描述

通过以下程序,可以将MNIST数据集二进制文件中的照片提取出来并以.png格式保存在文件夹中:

# 将二进制的图像文件中的图像提取出来并保存在文件夹中
for i in range(1, 60001):
    image = [item for item in file_images[16 + 28 * 28 * (i - 1):16 + 28 * 28 * i]]
    image_np = np.array(image, dtype=np.uint8).reshape(28, 28)
    im = Image.fromarray(image_np)
    im.save("G:\\mnist_dataset\\train-images" + "\\" + str(i) + ".png")

输出的部分照片如下所示:
在这里插入图片描述

通过以下程序,将二进制标签文件中的部分标签信息打印出来,可以发现,标签中的数据正对应于图像中的手写数字信息。

# 将二进制的标签文件中的部分标签信息打印出来
for i in range(40, 53):
    labels = int.from_bytes(file_labels[8 + i - 1:8 + i], 'big')
    print('labels' + str(i) + '=' + str(labels))

在这里插入图片描述

附录

程序源码

import numpy as np
from PIL import Image

MNIST_labels_path = 'G:\\mnist_dataset\\train-labels-idx1-ubyte\\train-labels.idx1-ubyte'  # 下载的MNIST数据集文件地址
MNIST_images_path = 'G:\\mnist_dataset\\train-images-idx3-ubyte\\train-images.idx3-ubyte'  # 下载的MNIST数据集文件地址

with open(MNIST_labels_path, 'rb') as f:
    file_labels = f.read()  # 读入标签二进制文件
with open(MNIST_images_path, 'rb') as f:
    file_images = f.read()  # 读入照片二进制文件

magic_number_labels = int.from_bytes(file_labels[0:4], 'big')  # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_items = int.from_bytes(file_labels[4:8], 'big')  # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
print('labels: magic number = ', magic_number_labels)
print('labels: number of items = ', number_items)

magic_number = int.from_bytes(file_images[0:4], 'big')  # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_images = int.from_bytes(file_images[4:8], 'big')  # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
number_rows = int.from_bytes(file_images[8:12], 'big')  # 读取二进制文件的第9-12个byte( 1byte = 8bit ),即number of rows,并转换成10进制
number_columns = int.from_bytes(file_images[12:16], 'big')  # 读取二进制文件的第13-16个byte( 1byte = 8bit ),即number of columns,并转换成10进制
print('images: magic number = ', magic_number)
print('images: number of images = ', number_images)
print('images: number of rows = ', number_rows)
print('images: number of columns = ', number_columns)

# 将二进制的图像文件中的图像提取出来并保存在文件夹中
for i in range(1, 60001):
    image = [item for item in file_images[16 + 28 * 28 * (i - 1):16 + 28 * 28 * i]]
    image_np = np.array(image, dtype=np.uint8).reshape(28, 28)
    im = Image.fromarray(image_np)
    im.save("G:\\mnist_dataset\\train-images" + "\\" + str(i) + ".png")

# 将二进制的标签文件中的部分标签信息打印出来
for i in range(40, 53):
    labels = int.from_bytes(file_labels[8 + i - 1:8 + i], 'big')
    print('labels' + str(i) + '=' + str(labels))

猜你喜欢

转载自blog.csdn.net/qq_30150579/article/details/133068622
今日推荐