# -*- coding: utf-8 -*-
# Sample code from "TensorFlow 实战 Google 深度学习框架" (2nd edition),
# Section 10.3.1, p. 278.
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Directory into which input_data downloads / caches the MNIST archives.
DATA_DIR = './mnist_data'
# Estimator checkpoints and event files. Deliberately NOT the same directory as
# DATA_DIR: sharing it mixes model artifacts with the downloaded dataset and
# makes it impossible to delete the model (to retrain from scratch) without
# also deleting the data.
MODEL_DIR = './mnist_model'
BATCH_SIZE = 128
# Estimator.train(steps=...) runs this many steps *on top of* any checkpoint
# already present in MODEL_DIR, so re-running the script continues training
# instead of starting over. Delete MODEL_DIR for a fresh run.
TRAIN_STEPS = 10000

tf.logging.set_verbosity(tf.logging.INFO)
# one_hot=False: labels come back as class indices rather than one-hot rows,
# matching the single-integer-label format used with n_classes below.
mnist = input_data.read_data_sets(DATA_DIR, one_hot=False)

# Each example is fed as a flat 784-value vector under the key 'image'.
feature_columns = [tf.feature_column.numeric_column('image', shape=[784])]

# One hidden layer of 500 units, 10 output classes, trained with Adam.
estimator = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[500],
    n_classes=10,
    optimizer=tf.train.AdamOptimizer(),
    model_dir=MODEL_DIR)


def _make_input_fn(images, labels, num_epochs, shuffle):
    """Build a numpy-backed Estimator input_fn over (images, labels).

    Args:
        images: array fed as the 'image' feature.
        labels: label array; cast to int32 class indices before feeding.
        num_epochs: passes over the data (None = repeat indefinitely).
        shuffle: whether to shuffle examples.

    Returns:
        An input_fn suitable for Estimator.train / Estimator.evaluate.
    """
    return tf.estimator.inputs.numpy_input_fn(
        x={'image': images},
        y=labels.astype(np.int32),
        num_epochs=num_epochs,
        batch_size=BATCH_SIZE,
        shuffle=shuffle)


# Repeat the training set indefinitely; TRAIN_STEPS bounds the run.
train_input_fn = _make_input_fn(
    mnist.train.images, mnist.train.labels, num_epochs=None, shuffle=True)

estimator.train(input_fn=train_input_fn, steps=TRAIN_STEPS)

# Exactly one ordered pass over the test set, so every example is scored once.
test_input_fn = _make_input_fn(
    mnist.test.images, mnist.test.labels, num_epochs=1, shuffle=False)

accuracy_score = estimator.evaluate(input_fn=test_input_fn)['accuracy']
print('\nTest accuracy:%g %%' % (accuracy_score * 100))

 

# Reprinted from (转载自) blog.csdn.net/Strive_For_Future/article/details/81561146