Using Python to process the MNIST data set

1. MNIST data set

1.1 What is the MNIST data set

The MNIST dataset is one of the most classic datasets for getting started with machine learning and pattern recognition. It was introduced in 1998 by Yann LeCun et al. in the paper "Gradient-based learning applied to document recognition". It contains images of the 10 handwritten digits 0 to 9. Each image has been size-normalized and is a 28x28 grayscale image. Pixel values range from 0 to 255, where 0 is black and 255 is white when displayed as a grayscale image. Some examples are shown below:

[Figure: examples of MNIST handwritten digit images]

MNIST contains a total of 70,000 images of handwritten digits: 60,000 form the training set and 10,000 form the test set. The original data files can be downloaded from the official MNIST website. After downloading, you get four compressed files:

train-images-idx3-ubyte.gz  # 60,000 training images
train-labels-idx1-ubyte.gz  # labels for the 60,000 training images
t10k-images-idx3-ubyte.gz   # 10,000 test images
t10k-labels-idx1-ubyte.gz   # labels for the 10,000 test images

Unzip the four compressed files above to get:

train-images-idx3-ubyte  # 60,000 training images, idx3-ubyte format
train-labels-idx1-ubyte  # labels for the 60,000 training images, idx1-ubyte format
t10k-images-idx3-ubyte   # 10,000 test images, idx3-ubyte format
t10k-labels-idx1-ubyte   # labels for the 10,000 test images, idx1-ubyte format
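Decompressing the files first is optional when using Python. As a minimal sketch, the standard-library gzip module can read the decompressed bytes straight from a .gz file; the helper name read_gz_bytes and the file location in the example comment are assumptions for illustration:

```python
import gzip


def read_gz_bytes(path):
    """Return the decompressed contents of a .gz file as bytes (hypothetical helper)."""
    # gzip.open in 'rb' mode decompresses transparently while reading
    with gzip.open(path, 'rb') as f:
        return f.read()


# Example, assuming the download sits in the current directory:
# file_labels = read_gz_bytes('train-labels-idx1-ubyte.gz')
```

The returned bytes are identical to what f.read() returns on the decompressed file, so the parsing code below works on either.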

1.2 MNIST data set file format

The four decompressed files are all binary. To extract their contents, you first need to understand how MNIST binary files are laid out. The layout of the training image file train-images-idx3-ubyte is described below:
[Figure: byte layout of the MNIST image file]

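  • Bytes 1 to 4 (1 byte = 8 bits), i.e. the first 32 bits, store the file's magic number, which is 2051 in decimal;
  • Bytes 5 to 8 store the number of images, 60,000;
  • Bytes 9 to 12 store the number of rows (height) of each image, 28;
  • Bytes 13 to 16 store the number of columns (width) of each image, 28;
  • From byte 17 onward, each byte stores the value of one pixel. Pixels are stored row by row, one image after another, so each image occupies 28 x 28 = 784 consecutive bytes.

All four header values are 32-bit unsigned integers stored in big-endian order (most significant byte first). The label file train-labels-idx1-ubyte has a shorter header: bytes 1 to 4 hold the magic number 2049, bytes 5 to 8 hold the number of labels (60,000), and from byte 9 onward each byte holds one label, a value from 0 to 9.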
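These header fields follow the general IDX format, in which the magic number itself describes the file: its first two bytes are always 0, the third byte gives the data type (0x08 means unsigned byte) and the fourth byte gives the number of dimensions. That is why the image file's magic number is 2051 (0x00000803, three dimensions) and the label file's is 2049 (0x00000801, one dimension). The following is a minimal sketch of a header decoder built on Python's standard struct module; the function name parse_idx_header is hypothetical and not part of any library:

```python
import struct


def parse_idx_header(data):
    """Decode an IDX header (hypothetical helper).

    Returns (type_code, dims, data_offset): the data-type byte, the size of
    each dimension, and the byte offset where the data itself starts.
    """
    # Magic number: two zero bytes, a data-type code, then the number of dimensions
    zero1, zero2, type_code, ndim = struct.unpack('>BBBB', data[:4])
    if zero1 != 0 or zero2 != 0:
        raise ValueError('not an IDX file: first two bytes must be zero')
    # Each dimension size is a 32-bit unsigned big-endian integer ('>I')
    dims = list(struct.unpack('>' + 'I' * ndim, data[4:4 + 4 * ndim]))
    data_offset = 4 + 4 * ndim  # 16 for the image file, 8 for the label file
    return type_code, dims, data_offset


# A synthetic header that mimics train-images-idx3-ubyte
header = struct.pack('>IIII', 2051, 60000, 28, 28)
print(parse_idx_header(header))  # (8, [60000, 28, 28], 16)
```

Deriving the offset from the number of dimensions lets the same function handle both the image file and the label file.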

1.3 Using Python to access the contents of the MNIST data set files

Now that we know how the MNIST binary files are stored, here is how to access their contents with Python, again using the training set label file train-labels-idx1-ubyte and the training set image file train-images-idx3-ubyte as examples:

import numpy as np
from PIL import Image

MNIST_labels_path = 'G:\\mnist_dataset\\train-labels-idx1-ubyte\\train-labels.idx1-ubyte'  # path to the downloaded MNIST label file
MNIST_images_path = 'G:\\mnist_dataset\\train-images-idx3-ubyte\\train-images.idx3-ubyte'  # path to the downloaded MNIST image file

with open(MNIST_labels_path, 'rb') as f:
    file_labels = f.read()  # read the binary label file
with open(MNIST_images_path, 'rb') as f:
    file_images = f.read()  # read the binary image file

magic_number_labels = int.from_bytes(file_labels[0:4], 'big')  # bytes 1-4 (1 byte = 8 bits): the magic number, as a decimal integer
number_items = int.from_bytes(file_labels[4:8], 'big')  # bytes 5-8: the number of items (labels), as a decimal integer
print('labels: magic number = ', magic_number_labels)
print('labels: number of items = ', number_items)

magic_number = int.from_bytes(file_images[0:4], 'big')  # bytes 1-4: the magic number, as a decimal integer
number_images = int.from_bytes(file_images[4:8], 'big')  # bytes 5-8: the number of images
number_rows = int.from_bytes(file_images[8:12], 'big')  # bytes 9-12: the number of rows (image height)
number_columns = int.from_bytes(file_images[12:16], 'big')  # bytes 13-16: the number of columns (image width)
print('images: magic number = ', magic_number)
print('images: number of images = ', number_images)
print('images: number of rows = ', number_rows)
print('images: number of columns = ', number_columns)

The program opens each file in binary mode with a with open(...) as f: block, reads its whole contents, and then uses int.from_bytes() with big-endian byte order to convert the header bytes into integers: the magic number and number of items of the label file, and the magic number, number of images, number of rows and number of columns of the image file. This tells us how many images there are and the height and width of each image.
Output:

[Figure: console output of the program above]
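The 'big' argument matters because the header integers are stored most significant byte first. A small check with the four magic-number bytes of the image file (0x00 0x00 0x08 0x03) shows what the wrong byte order would produce; the helper check_magic is hypothetical, added only for illustration:

```python
# The first four bytes of train-images-idx3-ubyte (its magic number)
magic_bytes = bytes([0x00, 0x00, 0x08, 0x03])

print(int.from_bytes(magic_bytes, 'big'))     # 2051: the documented magic number
print(int.from_bytes(magic_bytes, 'little'))  # 50855936: same bytes, wrong byte order


def check_magic(data, expected):
    """Raise ValueError if data does not start with the expected magic number (hypothetical helper)."""
    magic = int.from_bytes(data[0:4], 'big')
    if magic != expected:
        raise ValueError(f'unexpected magic number {magic}, expected {expected}')
```

With the variables from the program above, check_magic(file_images, 2051) and check_magic(file_labels, 2049) would fail early if, for example, a file had not been decompressed (a gzip file starts with the bytes 0x1f 0x8b instead).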

Images can be extracted from the MNIST binary image file and saved to a folder in .png format with the following program:

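# Extract the images from the binary image file and save them in a folder
import os

output_dir = 'G:\\mnist_dataset\\train-images'
os.makedirs(output_dir, exist_ok=True)  # Image.save() fails if the folder does not exist
image_size = number_rows * number_columns  # 28 * 28 = 784 bytes per image
for i in range(1, number_images + 1):
    start = 16 + image_size * (i - 1)  # pixel data starts after the 16-byte header
    image_np = np.frombuffer(file_images[start:start + image_size], dtype=np.uint8).reshape(number_rows, number_columns)
    im = Image.fromarray(image_np)
    im.save(os.path.join(output_dir, str(i) + '.png'))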

Some of the extracted images are shown below:
[Figure: some of the extracted digit images]
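Writing 60,000 PNG files is handy for inspection, but for training it is usually more convenient to hold every image in one NumPy array. A minimal sketch using np.frombuffer, assuming file_images holds the raw bytes of the image file with the 16-byte header described above; load_images is a hypothetical helper name:

```python
import numpy as np


def load_images(file_images):
    """Return all images as a uint8 array of shape (count, rows, cols) (hypothetical helper)."""
    count = int.from_bytes(file_images[4:8], 'big')
    rows = int.from_bytes(file_images[8:12], 'big')
    cols = int.from_bytes(file_images[12:16], 'big')
    # Skip the 16-byte header and view the remaining bytes as unsigned 8-bit pixels
    pixels = np.frombuffer(file_images, dtype=np.uint8, offset=16)
    return pixels.reshape(count, rows, cols)
```

With the data loaded this way, images = load_images(file_images) gives images[0] as the digit that the loop above saves as 1.png. np.frombuffer returns a read-only view of the bytes object, so call .copy() before modifying the array.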

The following program prints some of the labels stored in the binary label file. The printed labels match the handwritten digits in the corresponding images (label i belongs to image i.png).

# Print some of the labels stored in the binary label file
for i in range(40, 53):
    label = file_labels[8 + i - 1]  # label of image i; indexing a bytes object returns an int
    print('labels' + str(i) + '=' + str(label))

[Figure: printed labels 40 to 52]
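In the same way, all 60,000 labels can be loaded into a single array instead of being read one byte at a time. A minimal sketch, assuming file_labels holds the raw bytes of the label file with its 8-byte header (magic number plus item count); load_labels is a hypothetical helper name:

```python
import numpy as np


def load_labels(file_labels):
    """Return all labels as a uint8 array of shape (count,) (hypothetical helper)."""
    count = int.from_bytes(file_labels[4:8], 'big')
    # Skip the 8-byte header; every following byte is one label in the range 0-9
    labels = np.frombuffer(file_labels, dtype=np.uint8, offset=8)
    if labels.size != count:
        raise ValueError(f'expected {count} labels, found {labels.size}')
    return labels
```

With labels = load_labels(file_labels), labels[i - 1] equals the value the loop above prints for image i, and np.bincount(labels, minlength=10) counts how many examples there are of each digit.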

Appendix

Program source code

import os

import numpy as np
from PIL import Image

MNIST_labels_path = 'G:\\mnist_dataset\\train-labels-idx1-ubyte\\train-labels.idx1-ubyte'  # path to the downloaded MNIST label file
MNIST_images_path = 'G:\\mnist_dataset\\train-images-idx3-ubyte\\train-images.idx3-ubyte'  # path to the downloaded MNIST image file

with open(MNIST_labels_path, 'rb') as f:
    file_labels = f.read()  # read the binary label file
with open(MNIST_images_path, 'rb') as f:
    file_images = f.read()  # read the binary image file

magic_number_labels = int.from_bytes(file_labels[0:4], 'big')  # bytes 1-4 (1 byte = 8 bits): the magic number, as a decimal integer
number_items = int.from_bytes(file_labels[4:8], 'big')  # bytes 5-8: the number of items (labels), as a decimal integer
print('labels: magic number = ', magic_number_labels)
print('labels: number of items = ', number_items)

magic_number = int.from_bytes(file_images[0:4], 'big')  # bytes 1-4: the magic number, as a decimal integer
number_images = int.from_bytes(file_images[4:8], 'big')  # bytes 5-8: the number of images
number_rows = int.from_bytes(file_images[8:12], 'big')  # bytes 9-12: the number of rows (image height)
number_columns = int.from_bytes(file_images[12:16], 'big')  # bytes 13-16: the number of columns (image width)
print('images: magic number = ', magic_number)
print('images: number of images = ', number_images)
print('images: number of rows = ', number_rows)
print('images: number of columns = ', number_columns)

# Extract the images from the binary image file and save them in a folder
output_dir = 'G:\\mnist_dataset\\train-images'
os.makedirs(output_dir, exist_ok=True)  # Image.save() fails if the folder does not exist
image_size = number_rows * number_columns  # 28 * 28 = 784 bytes per image
for i in range(1, number_images + 1):
    start = 16 + image_size * (i - 1)  # pixel data starts after the 16-byte header
    image_np = np.frombuffer(file_images[start:start + image_size], dtype=np.uint8).reshape(number_rows, number_columns)
    im = Image.fromarray(image_np)
    im.save(os.path.join(output_dir, str(i) + '.png'))

# Print some of the labels stored in the binary label file
for i in range(40, 53):
    label = file_labels[8 + i - 1]  # label of image i; indexing a bytes object returns an int
    print('labels' + str(i) + '=' + str(label))


Source: blog.csdn.net/qq_30150579/article/details/133068622