TensorFlow MNIST 数据集

TensorFlow 自带的 MNIST 数据集的基本情况

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# The dataset is downloaded automatically on first use, which is slow;
# later runs reuse the files already stored under data/.
mnist = input_data.read_data_sets('data/', one_hot=True)
print(type(mnist))

# Example counts per split (the original author saw 55000 train, 10000 test).
for _split in (mnist.train, mnist.test):
    print(_split.num_examples)

img_train, label_train = mnist.train.images, mnist.train.labels
img_test, label_test = mnist.test.images, mnist.test.labels

# Every split is exposed as NumPy arrays. Per the original author's output,
# images are flattened 28*28 = 784-pixel rows and labels are one-hot rows of
# length 10 (one_hot=True above), so argmax recovers the class index.
_arrays = (img_train, label_train, img_test, label_test)
for _arr in _arrays:
    print(type(_arr))  # <class 'numpy.ndarray'>
for _arr in _arrays:
    print(_arr.shape)  # (55000, 784), (55000, 10), (10000, 784), (10000, 10)

# Pick a few random training indices (sampled with replacement by randint).
num_sample = 5
rand_idx = np.random.randint(img_train.shape[0], size=num_sample)

# Render each selected image as a 28x28 grayscale matrix and print its label,
# recovered as the argmax of the one-hot label row.
for i in rand_idx:
    cur_img = img_train[i, :].reshape((28, 28))
    cur_label = label_train[i, :].argmax()
    plt.matshow(cur_img, cmap=plt.get_cmap('gray'))
    print(str(i) + "训练数据的标签是" + str(cur_label))
# NOTE: the figures are only created here; call plt.show() to display them
# when running with a non-interactive matplotlib backend.

# Fetch one mini-batch of training data; next_batch returns an
# (images, labels) pair.
batch_size = 100
batch_x, batch_y = mnist.train.next_batch(batch_size)
for _batch_arr in (batch_x, batch_y):
    print(type(_batch_arr))  # <class 'numpy.ndarray'>
for _batch_arr in (batch_x, batch_y):
    print(_batch_arr.shape)  # (100, 784) for images, (100, 10) for labels

猜你喜欢

转载自blog.51cto.com/5669384/2415956