TensorFlow Tutorials (8): Introduction to the MNIST Handwritten Digit Dataset

Notice: All rights reserved. To repost, please contact the author and credit the source:  http://blog.csdn.net/u013719780?viewmode=contents


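When running machine learning experiments, the first thing we need is a common benchmark dataset, so that our results can be compared with those obtained by other algorithms. In image classification, MNIST is one such dataset. The previous posts in this series all used MNIST, and this post gives a brief introduction to it.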


MNIST

In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

%matplotlib inline  
print ("packs loaded")
packs loaded

Download and Extract MNIST dataset

In [2]:
print ("Download and Extract MNIST dataset")
mnist = input_data.read_data_sets('/tmp/data/', one_hot=True)
print ("")
print (" type of 'mnist' is %s" % (type(mnist)))
print (" number of train data is %d" % (mnist.train.num_examples))
print (" number of test data is %d" % (mnist.test.num_examples))
Download and Extract MNIST dataset
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz

 type of 'mnist' is <class 'collections.Datasets'>
 number of train data is 55000
 number of test data is 10000
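
The four archives extracted above are stored in the IDX binary format. As a rough illustration of what a loader has to decode, the sketch below parses an IDX byte buffer with NumPy. The helper name parse_idx is hypothetical (it is not part of TensorFlow), it handles only unsigned-byte data, and it is fed a tiny synthetic buffer rather than the real files.

```python
import struct
import numpy as np

def parse_idx(raw):
    """Parse an uncompressed IDX buffer of unsigned bytes into an ndarray.

    Hypothetical helper for illustration; not part of TensorFlow.
    """
    # Header: two zero bytes, a type code (0x08 = unsigned byte),
    # and the number of dimensions.
    zero, dtype_code, ndim = struct.unpack('>HBB', raw[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("not an unsigned-byte IDX buffer")
    # Each dimension size is a 4-byte big-endian unsigned integer.
    dims = struct.unpack('>' + 'I' * ndim, raw[4:4 + 4 * ndim])
    data = np.frombuffer(raw, dtype=np.uint8, offset=4 + 4 * ndim)
    return data.reshape(dims)

# Build a tiny synthetic "image file": 2 images of 3x3 pixels.
header = struct.pack('>HBBIII', 0, 0x08, 3, 2, 3, 3)
pixels = np.arange(18, dtype=np.uint8).tobytes()
images = parse_idx(header + pixels)
print(images.shape)  # (2, 3, 3)
```

For a real file, the bytes could be obtained with gzip.open(path, 'rb').read(), where path points at one of the archives listed in the output above.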
In [3]:
# What does the data of MNIST look like? 
print ("What does the data of MNIST look like?")
trainimg   = mnist.train.images
trainlabel = mnist.train.labels
testimg    = mnist.test.images
testlabel  = mnist.test.labels
print ("")
print (" type of 'trainimg' is %s"    % (type(trainimg)))
print (" type of 'trainlabel' is %s"  % (type(trainlabel)))
print (" type of 'testimg' is %s"     % (type(testimg)))
print (" type of 'testlabel' is %s"   % (type(testlabel)))
print (" shape of 'trainimg' is %s"   % (trainimg.shape,))
print (" shape of 'trainlabel' is %s" % (trainlabel.shape,))
print (" shape of 'testimg' is %s"    % (testimg.shape,))
print (" shape of 'testlabel' is %s"  % (testlabel.shape,))
What does the data of MNIST look like?

 type of 'trainimg' is <type 'numpy.ndarray'>
 type of 'trainlabel' is <type 'numpy.ndarray'>
 type of 'testimg' is <type 'numpy.ndarray'>
 type of 'testlabel' is <type 'numpy.ndarray'>
 shape of 'trainimg' is (55000, 784)
 shape of 'trainlabel' is (55000, 10)
 shape of 'testimg' is (10000, 784)
 shape of 'testlabel' is (10000, 10)
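
The shapes above follow from two conventions: each 28 x 28 image is flattened into a row of 784 values, and, because one_hot=True was passed, each label is a 10-element one-hot vector. A minimal sketch of both steps on synthetic data is given below. The helper to_one_hot and the fake_* arrays are assumptions made for illustration, not names from the tutorial.

```python
import numpy as np

def to_one_hot(labels, num_classes=10):
    """Turn integer class labels into one-hot rows (hypothetical helper)."""
    one_hot = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    # Set exactly one entry per row: column index = class label.
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot

# Synthetic stand-ins for three 28x28 images and their digit labels.
fake_images = np.random.rand(3, 28, 28).astype(np.float32)
fake_labels = np.array([5, 0, 9])

flat_images = fake_images.reshape(fake_images.shape[0], -1)  # (3, 784)
onehot_labels = to_one_hot(fake_labels)                      # (3, 10)

print(flat_images.shape, onehot_labels.shape)
# argmax along each row recovers the integer labels.
print(np.argmax(onehot_labels, axis=1))
```

The argmax call in the last line is the same decoding step used when the labels are printed in the next cell.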
In [4]:
# What does the training data look like?
print ("What does the training data look like?")
nsample = 5
randidx = np.random.randint(trainimg.shape[0], size=nsample)

for i in randidx:
    curr_img   = np.reshape(trainimg[i, :], (28, 28)) # restore the 28 by 28 matrix
    curr_label = np.argmax(trainlabel[i, :]) # decode the one-hot label
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.title("Training sample " + str(i)
              + ", label is " + str(curr_label))
    print ("Training sample " + str(i)
           + ", label is " + str(curr_label))
What does the training data look like?
Training sample 12118, label is 5
Training sample 46324, label is 8
Training sample 33, label is 4
Training sample 36491, label is 3
Training sample 6910, label is 3
[Figure: the five sampled training digits, each drawn as a 28 x 28 grayscale image by plt.matshow]
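
The loop above opens one figure per sample. If a single image is preferred, the flattened rows can be reshaped and placed side by side first. The sketch below does this with NumPy only, using random data as a stand-in for rows of trainimg. The helper tile_row is a hypothetical name, not part of the tutorial.

```python
import numpy as np

def tile_row(flat_images, side=28):
    """Place flattened square images side by side in one 2-D array
    (hypothetical helper)."""
    n = flat_images.shape[0]
    imgs = flat_images.reshape(n, side, side)
    # hstack concatenates the n (side x side) images along the width axis.
    return np.hstack(list(imgs))

# Synthetic stand-in for five rows of trainimg.
sample = np.random.rand(5, 784).astype(np.float32)
strip = tile_row(sample)
print(strip.shape)  # (28, 140)
```

With matplotlib available, plt.imshow(strip, cmap='gray') would then show all five samples in one figure.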
In [5]:
# Batch Learning? 
print ("Batch Learning? ")
batch_size = 100
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
print ("type of 'batch_xs' is %s" % (type(batch_xs)))
print ("type of 'batch_ys' is %s" % (type(batch_ys)))
print ("shape of 'batch_xs' is %s" % (batch_xs.shape,))
print ("shape of 'batch_ys' is %s" % (batch_ys.shape,))
Batch Learning? 
type of 'batch_xs' is <type 'numpy.ndarray'>
type of 'batch_ys' is <type 'numpy.ndarray'>
shape of 'batch_xs' is (100, 784)
shape of 'batch_ys' is (100, 10)
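
To give a feel for what a next_batch-style method does, here is a simplified sketch of an iterator that serves shuffled mini-batches and reshuffles at each new epoch. It is an illustration under stated assumptions, not the library's implementation: the class name BatchIterator is hypothetical, the data is synthetic, and leftover samples at the end of an epoch are simply dropped.

```python
import numpy as np

class BatchIterator(object):
    """Serve shuffled mini-batches, reshuffling at each new epoch
    (illustrative only)."""

    def __init__(self, images, labels, seed=0):
        self.images, self.labels = images, labels
        self.rng = np.random.RandomState(seed)
        self.order = self.rng.permutation(images.shape[0])
        self.pos = 0

    def next_batch(self, batch_size):
        # Start a new epoch when the remaining samples cannot fill a batch;
        # the leftover samples of the old epoch are dropped.
        if self.pos + batch_size > self.order.shape[0]:
            self.order = self.rng.permutation(self.order.shape[0])
            self.pos = 0
        idx = self.order[self.pos:self.pos + batch_size]
        self.pos += batch_size
        return self.images[idx], self.labels[idx]

# Synthetic data with the same column layout as the MNIST arrays.
images = np.random.rand(250, 784).astype(np.float32)
labels = np.eye(10, dtype=np.float32)[np.random.randint(10, size=250)]

it = BatchIterator(images, labels)
xs, ys = it.next_batch(100)
print(xs.shape, ys.shape)  # (100, 784) (100, 10)
```

Within one epoch no sample is served twice, which is the main difference from drawing random indices independently for every batch.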
In [6]:
# Get Random Batch with 'np.random.randint'
print ("5. Get Random Batch with 'np.random.randint'")
randidx   = np.random.randint(trainimg.shape[0], size=batch_size)
batch_xs2 = trainimg[randidx, :]
batch_ys2 = trainlabel[randidx, :]
print ("type of 'batch_xs2' is %s" % (type(batch_xs2)))
print ("type of 'batch_ys2' is %s" % (type(batch_ys2)))
print ("shape of 'batch_xs2' is %s" % (batch_xs2.shape,))
print ("shape of 'batch_ys2' is %s" % (batch_ys2.shape,))
5. Get Random Batch with 'np.random.randint'
type of 'batch_xs2' is <type 'numpy.ndarray'>
type of 'batch_ys2' is <type 'numpy.ndarray'>
shape of 'batch_xs2' is (100, 784)
shape of 'batch_ys2' is (100, 10)
In [7]:
randidx
Out[7]:
array([51472, 13751, 33562, 23281,  8489, 48481,  7799, 30307, 37366,
       25312, 46149, 49712,  5083, 52853, 29819, 36444, 34829,  8769,
       39518, 54911,  6720, 43675, 41703, 35594,  9300, 14474, 33318,
       14808, 53456, 41978,  8047, 34524, 30978, 53455, 42119, 22660,
       30329, 27169, 53798,  2125, 41759, 38951,  1438, 33511, 38784,
       15822, 16785,  9229,  1216, 19569,  3116, 22172, 14766, 16153,
        1707, 20899,  9087, 21263, 24853, 27784, 38324, 29287, 21828,
       34511, 26340, 39194, 38272, 34238, 28050, 29294, 42672, 18696,
       17796, 48147, 41841, 47077,  5925, 48237, 30605,  9169, 11260,
        9155, 39346, 41049, 11342,   536,  5927, 11155, 40424, 33583,
       38991, 16569, 34801,   870, 20546, 25061, 17601,  4521, 24359,  4613])
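
Note that np.random.randint draws every index independently, so the same index can appear more than once in a batch (sampling with replacement). If a batch of distinct samples is wanted, one option is np.random.choice with replace=False. The sketch below compares the two. The seed and the variable names are chosen only for illustration.

```python
import numpy as np

rng = np.random.RandomState(0)
num_train, batch_size = 55000, 100

# With replacement: duplicates are possible, though rare at this ratio.
idx_with = rng.randint(num_train, size=batch_size)

# Without replacement: every index in the batch is distinct.
idx_without = rng.choice(num_train, size=batch_size, replace=False)

# Number of distinct indices in each batch.
print(len(np.unique(idx_with)), len(np.unique(idx_without)))
```

Either index array can be used exactly like randidx above, for example trainimg[idx_without, :].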

Copyright notice: This is an original article by the author and may not be reposted without the author's permission. https://blog.csdn.net/u013719780/article/details/53815266

Reposted from blog.csdn.net/fdbvm/article/details/80984182