The MNIST dataset: file format and how to read it

Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/DarrenXf/article/details/85232255

MNIST website
http://yann.lecun.com/exdb/mnist/

The four files:

train-images-idx3-ubyte.gz:  training set images (9912422 bytes) 
train-labels-idx1-ubyte.gz:  training set labels (28881 bytes) 
t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes) 
t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes)

After downloading them, decompress:

$ gunzip *.gz

t10k-images-idx3-ubyte
train-images-idx3-ubyte
t10k-labels-idx1-ubyte
train-labels-idx1-ubyte

Decompressing produces the four files listed above.
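If `gunzip` is not available (for example on Windows), Python's standard `gzip` and `shutil` modules can do the same job. A minimal sketch; `gunzip_file` is a hypothetical helper that writes the decompressed file next to the archive and, unlike `gunzip`, keeps the `.gz` file.

```python
import gzip
import shutil

def gunzip_file(gz_path):
    """Decompress gz_path next to itself and return the output path.

    Like `gunzip`, the output name is the input name without ".gz";
    unlike `gunzip`, the .gz file is kept.
    """
    if not gz_path.endswith(".gz"):
        raise ValueError("expected a .gz file: %s" % gz_path)
    out_path = gz_path[:-3]
    with gzip.open(gz_path, "rb") as src, open(out_path, "wb") as dst:
        shutil.copyfileobj(src, dst)  # stream in chunks, no full read into memory
    return out_path

# Example (assumes the four downloads are in the current directory):
# for name in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
#              "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
#     gunzip_file(name)
```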

File format

There are 4 files:

train-images-idx3-ubyte: training set images 
train-labels-idx1-ubyte: training set labels 
t10k-images-idx3-ubyte:  test set images 
t10k-labels-idx1-ubyte:  test set labels

The training set contains 60000 examples, and the test set 10000 examples.

The first 5000 examples of the test set are taken from the original NIST training set. The last 5000 are taken from the original NIST test set. The first 5000 are cleaner and easier than the last 5000.

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000801(2049) magic number (MSB first) 
0004     32 bit integer  60000            number of items 
0008     unsigned byte   ??               label 
0009     unsigned byte   ??               label 
........ 
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000801(2049) magic number (MSB first) 
0004     32 bit integer  10000            number of items 
0008     unsigned byte   ??               label 
0009     unsigned byte   ??               label 
........ 
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  10000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black). 
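Since every image is a fixed block of 28 × 28 = 784 bytes stored row by row right after the 16-byte header, and every label is one byte after the 8-byte header, the position of any pixel or label can be computed directly. That makes it possible to `seek` to the i-th example instead of reading the file from the start. A minimal sketch; `pixel_offset`, `label_offset` and `read_image_at` are hypothetical helpers, and the demo uses an in-memory file in place of a real one.

```python
import io
import struct

IMAGE_HEADER_LEN = 16
LABEL_HEADER_LEN = 8
ROWS, COLS = 28, 28

def pixel_offset(i, r, c):
    """Byte offset of pixel (row r, column c) of image i in an idx3 image file."""
    return IMAGE_HEADER_LEN + i * ROWS * COLS + r * COLS + c

def label_offset(i):
    """Byte offset of label i in an idx1 label file."""
    return LABEL_HEADER_LEN + i

def read_image_at(f, i):
    """Seek to image i in an open binary file and return its 784 raw bytes."""
    f.seek(pixel_offset(i, 0, 0))
    return f.read(ROWS * COLS)

print(pixel_offset(0, 0, 0))  # 16, the "0016 pixel" row in the table
print(pixel_offset(2, 3, 4))  # 1672
print(label_offset(0))        # 8, the "0008 label" row in the table

# In-memory stand-in for an image file holding two images:
# image 0 is all background (0), image 1 is all foreground (255).
fake = io.BytesIO(struct.pack(">4I", 2051, 2, ROWS, COLS)
                  + bytes(ROWS * COLS) + bytes([255]) * (ROWS * COLS))
print(read_image_at(fake, 1)[:3])  # b'\xff\xff\xff'
```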

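The first 16 bytes of an image file are the header: a 4-byte magic number, a 4-byte image count, a 4-byte row count per image, and a 4-byte column count per image.
The first 8 bytes of a label file are the header: a 4-byte magic number and a 4-byte label count. All header fields are 32-bit integers stored MSB first (big-endian).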
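These header layouts map directly onto `struct` format strings: `>` selects big-endian byte order (the spec's "MSB first") and `I` is a 4-byte unsigned integer, so `">4I"` reads the 16-byte image header and `">2I"` the 8-byte label header. A sketch with a hypothetical `parse_header` helper that picks the header length from the magic number; the last line shows what goes wrong if the byte order is omitted on a little-endian machine.

```python
import struct

def parse_header(raw):
    """Parse the header at the start of an MNIST file (raw = its first bytes)."""
    magic, = struct.unpack(">I", raw[:4])
    if magic == 2051:                        # image file: 16-byte header
        return struct.unpack(">4I", raw[:16])
    if magic == 2049:                        # label file: 8-byte header
        return struct.unpack(">2I", raw[:8])
    raise ValueError("unknown magic number %d" % magic)

print(parse_header(struct.pack(">4I", 2051, 60000, 28, 28)))  # (2051, 60000, 28, 28)
print(parse_header(struct.pack(">2I", 2049, 60000)))          # (2049, 60000)

# Reading the same 4 bytes as little-endian gives a meaningless value:
print(struct.unpack("<I", struct.pack(">I", 2051))[0])         # 50855936
```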

Next, read the file headers:

from __future__ import division                                                 
from __future__ import print_function                                           
                                                                                
#gunzip *.gz                                                                    
#http://yann.lecun.com/exdb/mnist/                                              
                                                                                
import os                                                                       
import sys                                                                      
import struct                                                                   
                                                                                
file_list = [                                                                   
            "train-images-idx3-ubyte",                                          
            "train-labels-idx1-ubyte",                                          
            "t10k-images-idx3-ubyte",                                           
            "t10k-labels-idx1-ubyte",                                           
            ]                                                                   
                                                                                
def create_path(path):                                                          
    if not os.path.isdir(path):                                                 
        os.makedirs(path)                                                       
                                                                                
def get_file_full_name(path, name):                                             
    create_path(path)                                                           
    if path[-1] == "/":                                                         
        full_name = path +  name                                                
    else:                                                                       
        full_name = path + "/" +  name                                          
    return full_name                                                            
                                                                                
def read_mnist(file_name):
    file_path = "/home/your/data/path"  # directory holding the unpacked files
    full_path = get_file_full_name(file_path, file_name)
    file_object = open(full_path, 'rb')  # binary mode: required on Python 3, also fine on Python 2
    return file_object

def get_file_header_data(file_name, header_len, unpack_str):
    f = read_mnist(file_name)
    raw_header = f.read(header_len)
    f.close()  # only the header is needed here
    header_data = struct.unpack(unpack_str, raw_header)
    return header_data
                                                                                
def show_images_file_header(file_name):                                         
    show_file_header(file_name, 16, ">4I")                                      
                                                                                
def show_labels_file_header(file_name):                                         
    show_file_header(file_name, 8, ">2I")                                       
                                                                                
def show_file_header(file_name, header_len, unpack_str):                        
    header_data = get_file_header_data(file_name, header_len, unpack_str)       
    print("%s header data:%s" % (file_name, header_data))                       
                                                                                
def show_mnist_file_header():                                                   
    train_images_file_name = file_list[0]                                       
    show_images_file_header(train_images_file_name)                             
                                                                                
    test_images_file_name = file_list[2]                                        
    show_images_file_header(test_images_file_name)                              
                                                                                
    train_labels_file_name = file_list[1]                                       
    show_labels_file_header(train_labels_file_name)                             
                                                                                
    test_labels_file_name = file_list[3]                                        
    show_labels_file_header(test_labels_file_name)                              
                                                                                
def run():                                                                      
    show_mnist_file_header()                                                    
                                                                                
run()              

Output:

train-images-idx3-ubyte header data:(2051, 60000, 28, 28)
t10k-images-idx3-ubyte header data:(2051, 10000, 28, 28)
train-labels-idx1-ubyte header data:(2049, 60000)
t10k-labels-idx1-ubyte header data:(2049, 10000)
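The magic numbers in this output are not arbitrary. The MNIST page explains that the first two bytes are always 0, the third byte encodes the element type (0x08 means unsigned byte) and the fourth byte is the number of dimensions. So 2051 = 0x00000803 describes a 3-dimensional array of unsigned bytes (images × rows × columns) and 2049 = 0x00000801 a 1-dimensional one (labels). A small sketch; `decode_magic` is a hypothetical helper:

```python
def decode_magic(magic):
    """Split an IDX magic number into (element type code, number of dimensions)."""
    data_type = (magic >> 8) & 0xFF   # third byte: 0x08 means unsigned byte
    ndim = magic & 0xFF               # fourth byte: number of dimensions
    return data_type, ndim

print(decode_magic(2051))  # (8, 3): count x rows x columns
print(decode_magic(2049))  # (8, 1): count
```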

Next, we read one image and display it together with its label.

from __future__ import division                                                 
from __future__ import print_function                                           
                                                                                
#gunzip *.gz                                                                    
#http://yann.lecun.com/exdb/mnist/                                              
                                                                                
import os                                                                       
import sys                                                                      
import struct                                                                   
import numpy as np                                                              
import matplotlib.pyplot as plt                                                 
from PIL import Image                                                           
                                                                                
file_list = [                                                                   
            "train-images-idx3-ubyte",                                          
            "train-labels-idx1-ubyte",                                          
            "t10k-images-idx3-ubyte",                                           
            "t10k-labels-idx1-ubyte",                                           
            ]                                                                   
                                                                                
def create_path(path):                                                          
    if not os.path.isdir(path):                                                 
        os.makedirs(path)                                                       
                                                                                
def get_file_full_name(path, name):                                             
    create_path(path)                                                           
    if path[-1] == "/":                                                         
        full_name = path +  name                                                
    else:                                                                       
        full_name = path + "/" +  name                                          
    return full_name                                                            
                                                                                
def read_mnist(file_name):                                                             
    file_path = "/home/your/data/path"                                         
    full_path = get_file_full_name(file_path, file_name)                        
    file_object = open(full_path, 'rb')  #python3 need rb  python2 r is ok         
    return file_object                                                          
                                                                                
def get_file_header_data(file_obj, header_len, unpack_str):                     
    raw_header = file_obj.read(header_len)                                      
    header_data = struct.unpack(unpack_str, raw_header)                         
    return header_data

def show_images_file_header(file_name):                                         
    show_file_header(file_name, 16, ">4I")                                      
                                                                                
def show_labels_file_header(file_name):                                         
    show_file_header(file_name, 8, ">2I")                                       
                                                                                
def show_file_header(file_name, header_len, unpack_str):                        
    file_obj = read_mnist(file_name)                                            
    header_data = get_file_header_data(file_obj, header_len, unpack_str)        
    show_file_header_data(file_name, header_data)                               
    file_obj.close()                                                            
                                                                                
def show_mnist_file_header():                                                   
    train_images_file_name = file_list[0]                                       
    show_images_file_header(train_images_file_name)                             
                                                                                
    test_images_file_name = file_list[2]                                        
    show_images_file_header(test_images_file_name)                              
                                                                                
    train_labels_file_name = file_list[1]                                       
    show_labels_file_header(train_labels_file_name)                             
                                                                                
    test_labels_file_name = file_list[3]                                        
    show_labels_file_header(test_labels_file_name)                              
                                                                                
def read_a_image(file_object):                                                  
    img = file_object.read(28*28)                                               
    tp = struct.unpack(">784B",img)                                             
    image = np.asarray(tp)                                                      
    image = image.reshape((28,28))                                              
    #image = image.astype(np.float64)                                           
    plt.imshow(image,cmap = plt.cm.gray)                                        
    plt.show()                                                                  
                                                                                
def read_a_label(file_object):                                                  
    img = file_object.read(1)                                                   
    tp = struct.unpack(">B",img)                                                
    print("the label is :%s" % tp[0])                                           
                                                                                
def show_file_header_data(file_name,header_data):                               
    print("%s header data:%s" % (file_name, header_data))
    
def show_a_image():                                                             
    images_file_name = file_list[0]                                             
    labels_file_name = file_list[1]                                             
    images_file = read_mnist(images_file_name)                                  
    header_data = get_file_header_data(images_file, 16, ">4I")                  
    show_file_header_data(images_file_name, header_data)                        
                                                                                
    labels_file = read_mnist(labels_file_name)                                  
    header_data = get_file_header_data(labels_file, 8, ">2I")                   
    show_file_header_data(labels_file_name, header_data)                        
                                                                                
    read_a_image(images_file)
    read_a_label(labels_file)
    images_file.close()
    labels_file.close()

def run():
    #show_mnist_file_header()
    show_a_image()

run()

Output:

train-images-idx3-ubyte header data:(2051, 60000, 28, 28)
train-labels-idx1-ubyte header data:(2049, 60000)
the label is :5

Then the image is displayed:
[Figure: the first training image, a handwritten digit 5]

Indeed, the image shows a 5, the same as its label.
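A note on the colors: the spec says 0 is background (white) and 255 is foreground (black), but matplotlib's `gray` colormap maps the smallest value to black and the largest to white, so `plt.imshow(image, cmap=plt.cm.gray)` shows a white digit on a black background. Passing `cmap=plt.cm.gray_r` reverses the map. When the pixels are fed to a model instead, they are commonly scaled to [0, 1] first. A minimal sketch with hypothetical helper names, assuming NumPy:

```python
import numpy as np

def to_unit_range(pixels):
    """Scale raw MNIST pixels (0..255, 0 = background) to float32 in [0, 1]."""
    return np.asarray(pixels, dtype=np.float32) / 255.0

def to_paper_intensity(pixels):
    """Invert so background -> 1.0 (white) and ink -> 0.0 (black), as the spec describes."""
    return 1.0 - to_unit_range(pixels)

raw = np.array([0, 255], dtype=np.uint8)
print(to_unit_range(raw).tolist())       # [0.0, 1.0]
print(to_paper_intensity(raw).tolist())  # [1.0, 0.0]
```

Showing `to_paper_intensity(image)` with the plain `gray` colormap gives the same black-on-white picture as `gray_r` on the raw bytes.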

Next, we modify the code so that it generates batches of data automatically.

from __future__ import division    
from __future__ import print_function    
    
#gunzip *.gz    
#http://yann.lecun.com/exdb/mnist/    
    
import os    
import sys    
import struct    
import numpy as np    
import matplotlib.pyplot as plt    
from PIL import Image    
    
file_list = [    
            "train-images-idx3-ubyte",    
            "train-labels-idx1-ubyte",    
            "t10k-images-idx3-ubyte",    
            "t10k-labels-idx1-ubyte",    
            ]    
    
def show_images_file_header(file_name):    
    show_file_header(file_name, 16, ">4I")    
    
def show_labels_file_header(file_name):    
    show_file_header(file_name, 8, ">2I")    
    
def show_file_header(file_name, header_len, unpack_str):    
    file_obj = read_mnist(file_name)    
    header_data = get_file_header_data(file_obj, header_len, unpack_str)        
    show_file_header_data(file_name, header_data)    
    file_obj.close()  
    
def show_mnist_file_header():    
    train_images_file_name = file_list[0]    
    show_images_file_header(train_images_file_name)    

    test_images_file_name = file_list[2]
    show_images_file_header(test_images_file_name)

    train_labels_file_name = file_list[1]
    show_labels_file_header(train_labels_file_name)

    test_labels_file_name = file_list[3]
    show_labels_file_header(test_labels_file_name)

def show_a_image(file_object):
    image = read_a_image(file_object)  # tuple of 784 pixel values
    image = np.asarray(image)
    image = image.reshape((28,28))
    plt.imshow(image,cmap = plt.cm.gray)
    plt.show()

def show_a_label(file_object):
    tp = read_a_label(file_object)  # 1-tuple, e.g. (5,)
    print("the label is :%s" % tp[0])

def show_file_header_data(file_name,header_data):
    print("%s header data:%s" % (file_name, header_data))

def show_first_example():
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    images_file = read_mnist(images_file_name)
    header_data = get_file_header_data(images_file, 16, ">4I")
    show_file_header_data(images_file_name, header_data)

    labels_file = read_mnist(labels_file_name)
    header_data = get_file_header_data(labels_file, 8, ">2I")
    show_file_header_data(labels_file_name, header_data)

    show_a_image(images_file)
    show_a_label(labels_file)
    images_file.close()
    labels_file.close()

def create_path(path):
    if not os.path.isdir(path):
        os.makedirs(path)

def get_file_full_name(path, name):
    create_path(path)
    if path[-1] == "/":
        full_name = path +  name
    else:
        full_name = path + "/" +  name
    return full_name

def read_mnist(file_name):         
    file_path = "/home/your/data/path"
    full_path = get_file_full_name(file_path, file_name)
    file_object = open(full_path, 'rb')  #python3 need rb  python2 r is ok      
    return file_object

def get_file_header_data(file_obj, header_len, unpack_str):
    raw_header = file_obj.read(header_len)
    header_data = struct.unpack(unpack_str, raw_header)
    return header_data

def read_a_image(file_object):
    raw_img = file_object.read(28*28)
    img = struct.unpack(">784B",raw_img)
    return img

def read_a_label(file_object):
    raw_label = file_object.read(1)
    label = struct.unpack(">B",raw_label)
    return label

def generate_a_batch(images_file_name,labels_file_name,batch_size=100):
    images_file = read_mnist(images_file_name)
    header_data = get_file_header_data(images_file, 16, ">4I")
    #show_file_header_data(images_file_name, header_data)

    labels_file = read_mnist(labels_file_name)
    header_data = get_file_header_data(labels_file, 8, ">2I")
    #show_file_header_data(labels_file_name, header_data)

    while True:
        images = []
        labels = []
        for i in range(batch_size):
            try:
                image = read_a_image(images_file)
                label = read_a_label(labels_file)
                images.append(image)
                labels.append(label)
            except Exception as err:
                print(err)
                break
        yield images,labels

def get_train_data_generator():
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    gennerator = generate_a_batch(images_file_name,labels_file_name)
    return gennerator

def get_test_data_generator():
    images_file_name = file_list[2]
    labels_file_name = file_list[3]
    gennerator = generate_a_batch(images_file_name,labels_file_name)
    return gennerator
    

def get_a_batch(data_generator):
    # the built-in next() works on both Python 2.6+ and Python 3
    batch_img, batch_labels = next(data_generator)
    return batch_img,batch_labels

def generate_test_batch():
    data_generator = get_test_data_generator()
    count = 1
    while count:
        batch_img,batch_labels = get_a_batch(data_generator)
        if not batch_img and not batch_labels:
            break
        batch_img = np.array(batch_img)
        batch_labels = np.array(batch_labels)
        print("img shape:%s label shape:%s count:%s" %(batch_img.shape,batch_labels.shape,count))
        count +=1
        
def generate_train_batch():
    epoch = 0
    while epoch<10:
        epoch += 1
        data_generator = get_train_data_generator()
        count = 1
        while count:
            batch_img,batch_labels = get_a_batch(data_generator)
            if not batch_img and not batch_labels:
                break
            batch_img = np.array(batch_img)
            batch_labels = np.array(batch_labels)
            print("epoch:%s img shape:%s label shape:%s count:%s" %(epoch,batch_img.shape,batch_labels.shape,count))
            count +=1

def run():
    generate_train_batch()
    generate_test_batch()

run()

The code above contains a lot of unused code. After deleting it,
we get:

from __future__ import division
from __future__ import print_function

#gunzip *.gz
#http://yann.lecun.com/exdb/mnist/

import os
import sys
import struct
import numpy as np

file_list = [
            "train-images-idx3-ubyte",
            "train-labels-idx1-ubyte",
            "t10k-images-idx3-ubyte",
            "t10k-labels-idx1-ubyte",
            ]

def create_path(path):
    if not os.path.isdir(path):
        os.makedirs(path)

def get_file_full_name(path, name):
    create_path(path)
    if path[-1] == "/":
        full_name = path +  name
    else:
        full_name = path + "/" +  name
    return full_name

def read_mnist(file_name):
    file_path = "/home/your/data/path"
    full_path = get_file_full_name(file_path, file_name)
    file_object = open(full_path, 'rb')  #python3 need rb  python2 r is ok
    return file_object

def get_file_header_data(file_obj, header_len, unpack_str):
    raw_header = file_obj.read(header_len)
    header_data = struct.unpack(unpack_str, raw_header)
    return header_data

def read_a_image(file_object):
    raw_img = file_object.read(28*28)
    img = struct.unpack(">784B",raw_img)
    return img

def read_a_label(file_object):
    raw_label = file_object.read(1)
    label = struct.unpack(">B",raw_label)
    return label

def generate_a_batch(images_file_name,labels_file_name,batch_size=100):
    images_file = read_mnist(images_file_name)
    header_data = get_file_header_data(images_file, 16, ">4I")
    labels_file = read_mnist(labels_file_name)
    header_data = get_file_header_data(labels_file, 8, ">2I")

    while True:
        images = []
        labels = []
        for i in range(batch_size):
            try:
                image = read_a_image(images_file)
                label = read_a_label(labels_file)
                images.append(image)
                labels.append(label)
            except Exception as err:
                print(err)
                break
        yield images,labels

def get_train_data_generator():
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    gennerator = generate_a_batch(images_file_name,labels_file_name)
    return gennerator 

def get_test_data_generator():
    images_file_name = file_list[2]
    labels_file_name = file_list[3]
    gennerator = generate_a_batch(images_file_name,labels_file_name)
    return gennerator 

def get_a_batch(data_generator):
    # the built-in next() works on both Python 2.6+ and Python 3
    batch_img, batch_labels = next(data_generator)
    return batch_img,batch_labels

def generate_test_batch():
    data_generator = get_test_data_generator()
    count = 1
    while count:
        batch_img,batch_labels = get_a_batch(data_generator)
        if not batch_img and not batch_labels:
            break
        batch_img = np.array(batch_img)
        batch_labels = np.array(batch_labels)
        print("img shape:%s label shape:%s count:%s" %(batch_img.shape,batch_labels.shape,count))
        count +=1

def generate_train_batch():
    epoch = 0
    while epoch<10:
        epoch += 1
        data_generator = get_train_data_generator()
        count = 1
        while count:
            batch_img,batch_labels = get_a_batch(data_generator)
            if not batch_img and not batch_labels:
                break
            batch_img = np.array(batch_img)
            batch_labels = np.array(batch_labels)
            print("epoch:%s img shape:%s label shape:%s count:%s" %(epoch,batch_img.shape,batch_labels.shape,count))
            count +=1

def run():
    generate_train_batch()
    generate_test_batch()

run()

The output is long, so it is not reproduced here. The code above works with both Python 2 and Python 3.
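An alternative to the streaming generator, sketched here rather than taken from the original post: the training images take only about 47 MB (60000 × 784 bytes), so each file can be loaded into a NumPy array once and batches sliced out of memory. In this sketch the generator simply ends after one pass, so a plain `for` loop stops by itself without the empty-batch check, files are closed by `with`, and the order is reshuffled every epoch. `load_idx` and `shuffled_batches` are hypothetical names; the code assumes Python 3 and NumPy 1.17 or newer (for `np.random.default_rng`). Labels come back with shape `(n,)` rather than the `(n, 1)` produced by `read_a_label`.

```python
import struct
import numpy as np

def load_idx(path):
    """Load a whole MNIST idx file (images or labels) into a uint8 NumPy array."""
    with open(path, "rb") as f:
        raw = f.read()
    magic, count = struct.unpack(">2I", raw[:8])
    if magic == 2051:    # images: 16-byte header, then count*rows*cols bytes
        rows, cols = struct.unpack(">2I", raw[8:16])
        return np.frombuffer(raw, dtype=np.uint8, offset=16).reshape(count, rows, cols)
    if magic == 2049:    # labels: 8-byte header, then count bytes
        return np.frombuffer(raw, dtype=np.uint8, offset=8)
    raise ValueError("unexpected magic number %d" % magic)

def shuffled_batches(images, labels, batch_size=100, rng=None):
    """Yield one epoch of (images, labels) batches in random order, then stop."""
    rng = np.random.default_rng() if rng is None else rng
    order = rng.permutation(len(labels))
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]   # fancy indexing returns copies

# Usage sketch (the path is the placeholder used in the post):
# images = load_idx("/home/your/data/path/train-images-idx3-ubyte")
# labels = load_idx("/home/your/data/path/train-labels-idx1-ubyte")
# for epoch in range(10):
#     for batch_img, batch_labels in shuffled_batches(images, labels):
#         pass  # train on the batch
```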


Reposted from blog.csdn.net/DarrenXf/article/details/85232255