TensorFlow-MNIST: Training and Testing

Copyright notice: this is an original article by the blogger, and reposting is welcome. https://blog.csdn.net/samylee/article/details/83860259

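I won't go over introductory TensorFlow concepts here, but jumping straight into the code feels a bit abrupt! So first, an early Happy Spring Festival to everyone.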

Hardware: NVIDIA GTX 1080

Software: Windows 7, Python 3.6.5, tensorflow-gpu 1.4.0

Alright, on to the code! It is broken down step by step and should be easy to follow.

Step 1: Import TensorFlow

import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf

Step 2: Define the weights and biases

#############################define weights and bias########################
w_conv1 = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev = 0.1))
b_conv1 = tf.Variable(tf.constant(0.1, shape = [32]))

w_conv2 = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev = 0.1))
b_conv2 = tf.Variable(tf.constant(0.1, shape = [64]))

w_fc1 = tf.Variable(tf.truncated_normal([7*7*64, 1024], stddev = 0.1))
b_fc1 = tf.Variable(tf.constant(0.1, shape = [1024]))

w_fc2 = tf.Variable(tf.truncated_normal([1*1*1024, 10], stddev = 0.1))
b_fc2 = tf.Variable(tf.constant(0.1, shape = [10]))
#############################################################################
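As a quick sanity check on the shapes above, the per-layer parameter counts can be tallied in plain Python. This is a minimal sketch: the `layers` dictionary and the `product`/`count_params` helpers are hypothetical names introduced here, with the shapes copied by hand from the `tf.Variable` definitions.

```python
# Tally trainable parameters for the weight/bias shapes defined above.
# Shapes are copied by hand from the tf.Variable calls; no TensorFlow needed.

def product(shape):
    """Number of elements in a tensor with the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n

# name -> (weight shape, bias shape)
layers = {
    "conv1": ([5, 5, 1, 32], [32]),
    "conv2": ([5, 5, 32, 64], [64]),
    "fc1": ([7 * 7 * 64, 1024], [1024]),
    "fc2": ([1 * 1 * 1024, 10], [10]),
}

def count_params(layers):
    """Return per-layer parameter counts (weights + biases) and their total."""
    per_layer = {name: product(w) + product(b) for name, (w, b) in layers.items()}
    return per_layer, sum(per_layer.values())

per_layer, total = count_params(layers)
for name, n in per_layer.items():
    print(f"{name}: {n}")
print(f"total: {total}")  # total: 3274634
```

With these shapes, fc1 alone holds 3,212,288 of the 3,274,634 parameters (about 98%), so the first fully connected layer dominates the model size.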

Step 3: Design the network architecture
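The `7*7*64` input size of fc1 follows from the padding and strides in the model code. With `padding='SAME'`, TensorFlow's spatial output size is `ceil(input_size / stride)` regardless of kernel size, so the stride-1 convolutions keep the size and each 2x2, stride-2 max pool halves it (28, then 14, then 7). A plain-Python sketch that traces the shapes; `same_out` and `trace_shapes` are hypothetical helpers, not TensorFlow functions.

```python
import math

def same_out(size, stride):
    """Spatial output size under TensorFlow 'SAME' padding: ceil(size / stride)."""
    return math.ceil(size / stride)

def trace_shapes(h=28, w=28, c=1):
    """Return (layer, (height, width, channels)) after each conv/pool layer."""
    # (name, stride, output channels); None means the channel count is unchanged (pooling)
    ops = [("conv1", 1, 32), ("pool1", 2, None), ("conv2", 1, 64), ("pool2", 2, None)]
    shapes = [("input", (h, w, c))]
    for name, stride, out_c in ops:
        h, w = same_out(h, stride), same_out(w, stride)
        c = out_c if out_c is not None else c
        shapes.append((name, (h, w, c)))
    return shapes

shapes = trace_shapes()
for name, shape in shapes:
    print(name, shape)
h, w, c = shapes[-1][1]
print("fc1 input size:", h * w * c)  # fc1 input size: 3136 (= 7 * 7 * 64)
```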

##############################define model###################################
#define input size
x_ = tf.placeholder(tf.float32, [None, 28*28])
y_ = tf.placeholder(tf.float32, [None, 10])

#reshape input data
x_input = tf.reshape(x_, [-1, 28, 28, 1])

#conv1
conv1 = tf.nn.conv2d(x_input, w_conv1, strides=[1,1,1,1], padding='SAME') + b_conv1
relu1 = tf.nn.relu(conv1)

#pool1
pool1 = tf.nn.max_pool(relu1, ksize=[1,2,2,1],strides=[1,2,2,1], padding='SAME')

#conv2
conv2 = tf.nn.conv2d(pool1, w_conv2, strides=[1,1,1,1], padding='SAME') + b_conv2
relu2 = tf.nn.relu(conv2)

#pool2
pool2 = tf.nn.max_pool(relu2, ksize=[1,2,2,1],strides=[1,2,2,1], padding='SAME')

#reshape pool2 for fc1
pool2_reshape = tf.reshape(pool2, [-1, 7*7*64])

#fc1
fc1 = tf.matmul(pool2_reshape, w_fc1) + b_fc1
relu3 = tf.nn.relu(fc1)

#dropout
keep_prob = tf.placeholder(tf.float32)
fc1_dropout = tf.nn.dropout(relu3, keep_prob)

#fc2
fc2 = tf.matmul(fc1_dropout, w_fc2) + b_fc2

#softmax
y_out = tf.nn.softmax(fc2)

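#define loss: softmax cross-entropy computed directly from the logits (fc2), which avoids NaN from tf.log(0)
cross_entropy_loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=fc2))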
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy_loss)

#define accuracy
correct_prediction = tf.equal(tf.argmax(y_out,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
###############################################################################
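Taking the log of a softmax output in a separate step is numerically fragile: if a probability underflows to 0, the log is `-inf` and the summed loss can become `inf` or `NaN`. Computing the cross-entropy directly from the logits with the log-sum-exp trick avoids this, which is the idea behind fused ops such as `tf.nn.softmax_cross_entropy_with_logits`. The plain-Python sketch below contrasts the two forms and adds an argmax accuracy helper mirroring `correct_prediction`/`accuracy`, using integer labels instead of one-hot vectors; all function names are hypothetical.

```python
import math

def naive_cross_entropy(logits, label):
    """-log(softmax(logits)[label]) computed in two steps: softmax first, then log."""
    exps = [math.exp(z) for z in logits]
    prob = exps[label] / sum(exps)
    return -math.log(prob)  # math.log(0.0) raises ValueError (TensorFlow would give -inf)

def stable_cross_entropy(logits, label):
    """Same quantity from the logits via log-sum-exp.

    log(sum(exp(z))) = m + log(sum(exp(z - m))) with m = max(z), so every
    exponent is <= 0 and nothing overflows; the log of a tiny probability
    is never taken explicitly.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

def accuracy(batch_logits, labels):
    """Fraction of rows whose argmax equals the integer label."""
    hits = sum(
        max(range(len(row)), key=row.__getitem__) == y
        for row, y in zip(batch_logits, labels)
    )
    return hits / len(labels)

print(naive_cross_entropy([2.0, 1.0, 0.1], 0), stable_cross_entropy([2.0, 1.0, 0.1], 0))
try:
    naive_cross_entropy([0.0, -800.0], 1)  # exp(-800) underflows to 0.0
except ValueError as err:
    print("naive version failed:", err)
print(stable_cross_entropy([0.0, -800.0], 1))  # 800.0
print(accuracy([[0.1, 0.9], [0.8, 0.2]], [1, 1]))  # 0.5
```

Because softmax is monotonic, the argmax of the logits equals the argmax of the softmax output, so the accuracy is the same whichever one it is computed on.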

Step 4: Load the data, then train and test the network (test accuracy: 96.5%)
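By default `read_data_sets` holds out 5,000 of the 60,000 MNIST training images as a validation set, leaving 55,000 in `mnist.train`. The training loop runs 1000 steps with batches of 50, i.e. 50,000 images, which is a little under one pass over the training set. The sketch below makes that arithmetic explicit with a simplified shuffling minibatch iterator; `minibatches` is a hypothetical helper, not the implementation of `next_batch`, and it simply starts a freshly shuffled epoch when fewer than `batch_size` examples remain.

```python
import random

def minibatches(num_examples, batch_size, num_steps, seed=0):
    """Yield index lists for num_steps minibatches, reshuffling at each new epoch.

    Simplified, hypothetical stand-in for DataSet.next_batch: leftover examples
    at an epoch boundary are skipped, which the real helper may handle differently.
    """
    rng = random.Random(seed)
    order = list(range(num_examples))
    rng.shuffle(order)
    pos = 0
    for _ in range(num_steps):
        if pos + batch_size > num_examples:  # not enough left: start a new epoch
            rng.shuffle(order)
            pos = 0
        yield order[pos:pos + batch_size]
        pos += batch_size

TRAIN_SIZE = 55000  # default size of mnist.train (5,000 images held out for validation)
indices = [i for batch in minibatches(TRAIN_SIZE, 50, 1000) for i in batch]
seen = len(indices)
print(seen, f"{seen / TRAIN_SIZE:.2f} epochs")  # 50000 0.91 epochs
print(len(set(indices)) == seen)  # True: no image repeats within the first epoch
```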

####################################start model################################
#load data
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)

#define sess
sess = tf.Session()
sess.run(tf.global_variables_initializer())

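#define save model (keep the 3 most recent checkpoints; create the checkpoint directory up front)
import os
os.makedirs("MNIST_model", exist_ok=True)
saver = tf.train.Saver(max_to_keep=3)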

#train
for step in range(1000):
    batch = mnist.train.next_batch(50)
    if step % 100 == 0:
        train_accuracy = accuracy.eval(session=sess, feed_dict={x_:batch[0], y_:batch[1], keep_prob:1.0})
        print("step %d, train_accuracy %g" %(step, train_accuracy))
        saver.save(sess, "MNIST_model/model.ckpt", global_step=step)  # writes MNIST_model/model.ckpt-<step>
    train_step.run(session=sess, feed_dict={x_:batch[0], y_:batch[1], keep_prob:0.5})

#test
print("test accuracy %g" %accuracy.eval(session=sess, feed_dict={x_:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))
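The loop feeds `keep_prob: 0.5` for training steps but `keep_prob: 1.0` whenever accuracy is evaluated. `tf.nn.dropout` scales the kept activations by `1 / keep_prob` during training (inverted dropout), so the expected activation is unchanged and no rescaling is needed at test time; with `keep_prob = 1.0` every activation passes through unscaled. A plain-Python sketch of that behavior; the `inverted_dropout` helper is hypothetical, not TensorFlow's implementation.

```python
import random

def inverted_dropout(values, keep_prob, rng):
    """Zero each value with probability 1 - keep_prob; scale survivors by 1 / keep_prob.

    rng.random() is in [0, 1), so with keep_prob == 1.0 every value is kept and
    divided by 1.0, i.e. the function returns the input unchanged.
    """
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]

rng = random.Random(42)
activations = [1.0] * 100_000
train_out = inverted_dropout(activations, 0.5, rng)  # like keep_prob: 0.5 in training
mean = sum(train_out) / len(train_out)
zero_fraction = train_out.count(0.0) / len(train_out)
print(f"mean {mean:.3f}, zeroed {zero_fraction:.3f}")  # mean near 1.0, about half zeroed
test_out = inverted_dropout([0.3, 1.7], 1.0, rng)  # like keep_prob: 1.0 at evaluation
print(test_out)  # [0.3, 1.7]
```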

Step 5: Closing remarks

The code is partly commented, so you can follow it through the comments. If anything is still unclear, try searching Baidu or Google.

That's a wrap!

For any questions, add me on QQ: 2258205918 (name: samylee).


Reposted from blog.csdn.net/samylee/article/details/83860259