TensorFlow 实战 Google 深度学习框架 (2nd edition) ---- code from Section 10.3.2, p. 280
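This listing is TensorFlow 1.x code: it builds a LeNet-style convolutional network with tf.layers, wraps it in a custom model_fn, and uses the tf.estimator.Estimator API to train on MNIST, report test accuracy, and predict the digits of a few test images.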

# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.logging.set_verbosity(tf.logging.INFO)

def lenet(x, is_training):
    # Reshape the flat 784-pixel input into a 28x28 single-channel image.
    x = tf.reshape(x, shape=[-1, 28, 28, 1])

    # Two convolution + max-pooling blocks, then a fully connected layer
    # with dropout (enabled only while training).
    net = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, 2, 2)
    net = tf.layers.conv2d(net, 64, 3, activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, 2, 2)
    net = tf.contrib.layers.flatten(net)
    net = tf.layers.dense(net, 1024)
    net = tf.layers.dropout(net, rate=0.4, training=is_training)

    # Final 10-way logits layer.
    return tf.layers.dense(net, 10)

def model_fn(features, labels, mode, params):
    # Dropout is active only in TRAIN mode.
    predict = lenet(
        features['image'], mode == tf.estimator.ModeKeys.TRAIN)

    # In PREDICT mode, return only the predicted class.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={'result': tf.argmax(predict, 1)})

    # Loss and training op used in TRAIN/EVAL modes.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=predict, labels=labels))

    optimizer = tf.train.GradientDescentOptimizer(
        learning_rate=params['learning_rate'])

    train_op = optimizer.minimize(
        loss=loss, global_step=tf.train.get_global_step())

    # Accuracy metric reported by estimator.evaluate().
    eval_metric_ops = {
        'my_metric': tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(predict, 1))
    }

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)

# Load MNIST with integer (non-one-hot) labels, as required by
# sparse_softmax_cross_entropy_with_logits.
mnist = input_data.read_data_sets('./mnist_data', one_hot=False)

model_params = {'learning_rate': 0.01}
estimator = tf.estimator.Estimator(model_fn=model_fn, params=model_params)
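
# The listing as posted never calls estimator.train(), so evaluate() would run
# on randomly initialized weights. The book trains the model first; a matching
# training step (the step count here is illustrative) looks like this:
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'image': mnist.train.images},
    y=mnist.train.labels.astype(np.int32),
    num_epochs=None,
    batch_size=128,
    shuffle=True)
estimator.train(input_fn=train_input_fn, steps=30000)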

# Evaluate the trained model on the MNIST test set.
test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'image': mnist.test.images},
    y=mnist.test.labels.astype(np.int32),
    num_epochs=1,
    batch_size=128,
    shuffle=False)

test_results = estimator.evaluate(input_fn=test_input_fn)

accuracy_score = test_results['my_metric']
print('\nTest accuracy: %g%%' % (accuracy_score * 100))

# Predict the class of the first ten test images.
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'image': mnist.test.images[:10]},
    num_epochs=1,
    shuffle=False)

predictions = estimator.predict(input_fn=predict_input_fn)

for i, p in enumerate(predictions):
    print('Prediction %s: %s' % (i + 1, p['result']))

Reposted from blog.csdn.net/Strive_For_Future/article/details/81564208