基于CNN的MNIST手写体识别代码2

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Train a small CNN (two conv + max-pool stages, one hidden dense layer) on
# MNIST with Adam for 1000 steps of batch size 100, printing
# (step, current-batch training loss, test-set accuracy) every 50 steps.

xs = tf.placeholder(tf.float32, [None, 784])  # flattened 28x28 images
ys = tf.placeholder(tf.float32, [None, 10])   # one-hot labels (one_hot=True below)
x_image = tf.reshape(xs, [-1, 28, 28, 1])

# Each 2x2/stride-2 'same' pool halves the spatial size: 28 -> 14 -> 7,
# which is why the flatten below uses 7 * 7 * 64.
conv1 = tf.layers.conv2d(x_image, 32, 5, 1, padding='same', activation=tf.nn.relu)
pool1 = tf.layers.max_pooling2d(conv1, 2, 2, padding='same')
conv2 = tf.layers.conv2d(pool1, 64, 5, 1, padding='same', activation=tf.nn.relu)
pool2 = tf.layers.max_pooling2d(conv2, 2, 2, padding='same')
pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
# ReLU is required here: without a nonlinearity, fc1 and the output layer
# would compose into a single linear map.
fc1 = tf.layers.dense(pool2_flat, 1024, activation=tf.nn.relu)
# Raw (unnormalized) class scores. softmax_cross_entropy applies softmax
# internally, so it must receive these logits, not softmax probabilities.
logits = tf.layers.dense(fc1, 10)
output = tf.nn.softmax(logits)  # class probabilities, for inference

loss = tf.losses.softmax_cross_entropy(onehot_labels=ys, logits=logits)
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)

# Build the evaluation ops once, on the placeholders. Creating them inside the
# training loop would add new nodes to the graph on every evaluation.
correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    for step in range(1000):
        xs_batch, ys_batch = mnist.train.next_batch(100)
        _, loss1 = sess.run([train_step, loss], feed_dict={xs: xs_batch, ys: ys_batch})
        if step % 50 == 0:
            result = sess.run(accuracy, feed_dict={xs: mnist.test.images, ys: mnist.test.labels})
            print(step, loss1, result)

运行结果（每 50 步输出一行：训练步数、当前批次的训练损失、测试集准确率）：

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
0 2.3022413 0.0963
50 1.8999649 0.486
100 1.7091469 0.7803
150 1.5964711 0.8999
200 1.5469754 0.9139
250 1.559378 0.9313
300 1.5206326 0.9365
350 1.5044914 0.9405
400 1.487135 0.9472
450 1.5237399 0.951
500 1.5056982 0.9567
550 1.4996839 0.9596
600 1.4718686 0.9575
650 1.5171963 0.9636
700 1.5034649 0.963
750 1.5026108 0.9646
800 1.4924903 0.966
850 1.4993373 0.9684
900 1.5047902 0.967
950 1.4989574 0.9717

猜你喜欢

转载自blog.csdn.net/luckyboy101/article/details/83181861