一、MNSIT数据处理
MNSIT是一个非常有名的手写体数字识别数据集。包含60000张训练图片,10000张测试图片。每张图片是28X28的数字。
TonserFlow提供了一个类来处理 MNSIT数据。这个类会自动下载并转化数据结构。
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist_data = input_data.read_data_sets("mnist_set",one_hot=True) # print training data size print("training_data_size",mnist_data.train.num_examples) # print validation data size print("validating_data_size",mnist_data.validation.num_examples) #print testing data size print("testing data size",mnist_data.test.num_examples) print("example train image :",mnist_data.train.images[0]) print("example train label :",mnist_data.train.labels[0])
为了方便使用随机梯度下降,
input_data.read_data_sets还提供train.next_batch函数
batch_size = 100 train_x ,train_y = mnist_data.train.next_batch(batch_size) print("X_shape",train_x.shape) print("Y_shape",train_y.shape) ## #X_shape (100, 784) #Y_shape (100, 10)
二、神经网络模型训练及不同模型效果的对比
1.TF训练神经网络