TensorFlow (7) -- Designing your own neural network for handwritten digit recognition, reaching 0.98 accuracy

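1. The network has two hidden layers with 500 and 300 neurons respectively, tanh activations, and dropout (keep probability 0.7 during training).
2. The optimizer is Adam.
3. The learning rate is 0.001*(0.95**epoch): it starts at 0.001 and is multiplied by 0.95 at the start of every epoch, so it decays exponentially as training proceeds.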
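
For a sense of scale, a fully connected network with MNIST's 784 input pixels, the two hidden layers above, and 10 output classes has a little over half a million trainable parameters. The small pure-Python sketch below counts weights plus biases layer by layer; the helper name `count_parameters` is illustrative, not part of TensorFlow.

```python
# Layer widths: 784 inputs, hidden layers of 500 and 300, 10 output classes.
LAYER_SIZES = [784, 500, 300, 10]

def count_parameters(sizes):
    """Count weights plus biases of a fully connected network with these layer widths."""
    total = 0
    for fan_in, fan_out in zip(sizes, sizes[1:]):
        total += fan_in * fan_out + fan_out  # weight matrix + bias vector
    return total

print(count_parameters(LAYER_SIZES))  # 545810
```

Almost three quarters of these (392,500) sit in the first layer, because every one of the 784 pixels connects to all 500 units.
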

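The schedule in item 3 is plain per-epoch exponential decay. A minimal pure-Python sketch of the same formula (the constant and function names are illustrative) also computes how many epochs it takes for the rate to halve:

```python
import math

BASE_LR = 0.001  # initial learning rate used in the post
DECAY = 0.95     # per-epoch decay factor used in the post

def learning_rate(epoch):
    """Exponential per-epoch decay: BASE_LR * DECAY ** epoch."""
    return BASE_LR * DECAY ** epoch

# Solve DECAY ** n = 0.5 for n: the number of epochs for the rate to halve.
half_life = math.log(0.5) / math.log(DECAY)

for epoch in (0, 1, 10, 50):
    print(f"epoch {epoch:2d}: lr = {learning_rate(epoch):.6g}")
print(f"half-life: {half_life:.1f} epochs")
```

The rate halves roughly every 13.5 epochs, so over the 51 epochs of training it falls by a factor of about 13. TensorFlow 1.x also provides `tf.train.exponential_decay`; with `staircase=True`, `decay_steps` set to the number of batches per epoch, and a global step passed to `minimize`, it should express the same schedule inside the graph.
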
Without further ado, here is the code:

# Requires TensorFlow 1.x (tensorflow.examples.tutorials.mnist is not shipped with TensorFlow 2.x)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the dataset (downloaded into MNIST_data/ if not already present); labels are one-hot vectors
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

# Batch size
batch_size = 100
# Number of batches per epoch
n_batch = mnist.train.num_examples // batch_size

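# Placeholders for the flattened 28x28 images (784 pixels) and the one-hot labels
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
# Dropout keep probability (0.7 while training, 1.0 when evaluating)
keep_prob = tf.placeholder(tf.float32)
# Learning rate, held in a Variable so it can be changed every epoch
lr = tf.Variable(0.001, dtype=tf.float32)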

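# Build the network
# Hidden layer 1: 784 -> 500 units, tanh activation, followed by dropout
W1 = tf.Variable(tf.truncated_normal([784,500],stddev=0.1))
b1 = tf.Variable(tf.zeros([500])+0.1)
L1 = tf.nn.tanh(tf.matmul(x,W1)+b1)
L1_drop = tf.nn.dropout(L1,keep_prob)

# Hidden layer 2: 500 -> 300 units, tanh activation, followed by dropout
W2 = tf.Variable(tf.truncated_normal([500,300],stddev=0.1))
b2 = tf.Variable(tf.zeros([300])+0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop,W2)+b2)
L2_drop = tf.nn.dropout(L2,keep_prob)

# Output layer: 300 -> 10 classes
W3 = tf.Variable(tf.truncated_normal([300,10],stddev=0.1))
b3 = tf.Variable(tf.zeros([10])+0.1)
# Raw scores (logits); softmax is applied separately to get class probabilities
logits = tf.matmul(L2_drop,W3)+b3
prediction = tf.nn.softmax(logits)

# Cross-entropy loss. softmax_cross_entropy_with_logits applies softmax internally,
# so it must receive the raw logits, not the softmax output
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=logits))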
# Training op: Adam with the (decaying) learning rate lr
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

# Op that initializes all variables
init = tf.global_variables_initializer()

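# Boolean tensor: True where the predicted class equals the true class
# argmax(..., 1) returns the index of the largest value in each row
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# Accuracy: mean of the booleans cast to float
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))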

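# Build the learning-rate update op once; calling tf.assign inside the loop
# would add a new op to the graph on every epoch
new_lr = tf.placeholder(tf.float32, shape=[])
update_lr = tf.assign(lr, new_lr)

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(51):
        # Exponential decay: 0.001 * 0.95**epoch
        sess.run(update_lr, feed_dict={new_lr: 0.001*(0.95**epoch)})
        for batch in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            # Train with dropout: each hidden unit is kept with probability 0.7
            sess.run(train_step,feed_dict = {x:batch_xs,y:batch_ys,keep_prob:0.7})

        # Evaluate on the full test set with dropout disabled
        learning_rate = sess.run(lr)
        acc = sess.run(accuracy,feed_dict = {x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
        print("Iter"+str(epoch)+",Testing Accuracy"+str(acc)+",Learning Rate"+str(learning_rate))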
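
`tf.nn.softmax_cross_entropy_with_logits` applies softmax to its `logits` argument internally, so feeding it the output of `tf.nn.softmax` applies softmax twice. The pure-Python sketch below reproduces that computation for one example; the 10-class score vector is made up for illustration. With raw logits the loss of a confident, correct prediction is close to zero, but with probabilities as input it can never drop below -log(e/(e+9)), about 1.46, because inputs confined to [0, 1] can separate the classes by at most a factor of e after the second softmax.

```python
import math

def softmax(z):
    """Numerically stable softmax over a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_from_logits(label_index, logits):
    """Per-example cross-entropy: softmax is applied internally, then -log(p[label])."""
    probs = softmax(logits)
    return -math.log(probs[label_index])

# Illustrative raw scores: the network is confident (and correct) that the digit is class 0.
raw_scores = [10.0] + [0.0] * 9

correct_loss = cross_entropy_from_logits(0, raw_scores)
# Passing probabilities instead of logits applies softmax a second time.
double_softmax_loss = cross_entropy_from_logits(0, softmax(raw_scores))
# Best case for the double-softmax loss: a perfect one-hot probability vector.
floor = -math.log(math.e / (math.e + 9))

print(f"raw logits:        {correct_loss:.4f}")
print(f"softmax as logits: {double_softmax_loss:.4f}")
print(f"lower bound:       {floor:.4f}")
```

Because the double-softmax loss saturates near this floor, its gradients shrink even while the predicted probabilities are still far from one-hot, which weakens the training signal.
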

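The script hands `lr` to `tf.train.AdamOptimizer`. For intuition, here is a minimal pure-Python sketch of the Adam update rule as given in Kingma and Ba's paper, applied to a toy one-dimensional quadratic. The objective, starting point, learning rate, and step count are illustrative assumptions; beta1=0.9, beta2=0.999, and epsilon=1e-8 match the defaults of `tf.train.AdamOptimizer`, whose implementation is algebraically equivalent apart from where epsilon enters.

```python
import math

def adam_minimize(grad, w0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=500):
    """Run Adam on a single scalar parameter and return its final value."""
    w, m, v = w0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(w)
        m = beta1 * m + (1 - beta1) * g      # first moment: running mean of gradients
        v = beta2 * v + (1 - beta2) * g * g  # second moment: running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)         # bias correction for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w

# Toy objective (assumption): f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w_final = adam_minimize(lambda w: 2.0 * (w - 3.0), w0=0.0)
print(f"{w_final:.4f}")
```

Dividing by the root of the running squared-gradient average keeps early steps close to `lr` in size regardless of how large the raw gradient is; in the first step above the parameter moves by exactly 0.1 even though the gradient is -6.
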
Output from the author's run:

Iter0,Testing Accuracy0.9347,Learning Rate0.001
Iter1,Testing Accuracy0.9502,Learning Rate0.00095
Iter2,Testing Accuracy0.9541,Learning Rate0.0009025
Iter3,Testing Accuracy0.9602,Learning Rate0.000857375
Iter4,Testing Accuracy0.9634,Learning Rate0.00081450626
Iter5,Testing Accuracy0.9653,Learning Rate0.0007737809
Iter6,Testing Accuracy0.9683,Learning Rate0.0007350919
Iter7,Testing Accuracy0.9695,Learning Rate0.0006983373
Iter8,Testing Accuracy0.9706,Learning Rate0.0006634204
Iter9,Testing Accuracy0.9734,Learning Rate0.0006302494
Iter10,Testing Accuracy0.9728,Learning Rate0.0005987369
Iter11,Testing Accuracy0.9735,Learning Rate0.0005688001
Iter12,Testing Accuracy0.9749,Learning Rate0.0005403601
Iter13,Testing Accuracy0.9743,Learning Rate0.0005133421
Iter14,Testing Accuracy0.9738,Learning Rate0.000487675
Iter15,Testing Accuracy0.9745,Learning Rate0.00046329122
Iter16,Testing Accuracy0.9767,Learning Rate0.00044012666
Iter17,Testing Accuracy0.9756,Learning Rate0.00041812033
Iter18,Testing Accuracy0.9756,Learning Rate0.00039721432
Iter19,Testing Accuracy0.9769,Learning Rate0.0003773536
Iter20,Testing Accuracy0.9769,Learning Rate0.00035848594
Iter21,Testing Accuracy0.9764,Learning Rate0.00034056162
Iter22,Testing Accuracy0.9761,Learning Rate0.00032353355
Iter23,Testing Accuracy0.9767,Learning Rate0.00030735688
Iter24,Testing Accuracy0.9775,Learning Rate0.000291989
Iter25,Testing Accuracy0.9767,Learning Rate0.00027738957
Iter26,Testing Accuracy0.9774,Learning Rate0.0002635201
Iter27,Testing Accuracy0.9785,Learning Rate0.00025034408
Iter28,Testing Accuracy0.9784,Learning Rate0.00023782688
Iter29,Testing Accuracy0.9785,Learning Rate0.00022593554
Iter30,Testing Accuracy0.9794,Learning Rate0.00021463877
Iter31,Testing Accuracy0.979,Learning Rate0.00020390682
Iter32,Testing Accuracy0.9789,Learning Rate0.00019371149
Iter33,Testing Accuracy0.9779,Learning Rate0.0001840259
Iter34,Testing Accuracy0.9797,Learning Rate0.00017482461
Iter35,Testing Accuracy0.979,Learning Rate0.00016608338
Iter36,Testing Accuracy0.9795,Learning Rate0.00015777921
Iter37,Testing Accuracy0.9798,Learning Rate0.00014989026
Iter38,Testing Accuracy0.9795,Learning Rate0.00014239574
Iter39,Testing Accuracy0.9805,Learning Rate0.00013527596
Iter40,Testing Accuracy0.9796,Learning Rate0.00012851215
Iter41,Testing Accuracy0.9803,Learning Rate0.00012208655
Iter42,Testing Accuracy0.9793,Learning Rate0.00011598222
Iter43,Testing Accuracy0.9797,Learning Rate0.00011018311
Iter44,Testing Accuracy0.9806,Learning Rate0.000104673956
Iter45,Testing Accuracy0.9808,Learning Rate9.944026e-05
Iter46,Testing Accuracy0.9807,Learning Rate9.446825e-05
Iter47,Testing Accuracy0.9799,Learning Rate8.974483e-05
Iter48,Testing Accuracy0.9811,Learning Rate8.525759e-05
Iter49,Testing Accuracy0.9807,Learning Rate8.099471e-05
Iter50,Testing Accuracy0.9803,Learning Rate7.6944976e-05
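
The accuracy column can also be summarized programmatically. This sketch parses lines in exactly the format printed by the script (the regular expression is written for that format only) and uses the last fourteen epochs of the log above, copied verbatim; every earlier epoch stays below 0.98.

```python
import re

# Epochs 37-50 of the log above (epochs 0-36 are all below 0.98).
LOG = """\
Iter37,Testing Accuracy0.9798,Learning Rate0.00014989026
Iter38,Testing Accuracy0.9795,Learning Rate0.00014239574
Iter39,Testing Accuracy0.9805,Learning Rate0.00013527596
Iter40,Testing Accuracy0.9796,Learning Rate0.00012851215
Iter41,Testing Accuracy0.9803,Learning Rate0.00012208655
Iter42,Testing Accuracy0.9793,Learning Rate0.00011598222
Iter43,Testing Accuracy0.9797,Learning Rate0.00011018311
Iter44,Testing Accuracy0.9806,Learning Rate0.000104673956
Iter45,Testing Accuracy0.9808,Learning Rate9.944026e-05
Iter46,Testing Accuracy0.9807,Learning Rate9.446825e-05
Iter47,Testing Accuracy0.9799,Learning Rate8.974483e-05
Iter48,Testing Accuracy0.9811,Learning Rate8.525759e-05
Iter49,Testing Accuracy0.9807,Learning Rate8.099471e-05
Iter50,Testing Accuracy0.9803,Learning Rate7.6944976e-05
"""

PATTERN = re.compile(r"Iter(\d+),Testing Accuracy([\d.]+),Learning Rate([\d.e+-]+)")

def parse_log(text):
    """Return (epoch, accuracy, learning_rate) tuples for every matching line."""
    rows = []
    for line in text.splitlines():
        m = PATTERN.fullmatch(line.strip())
        if m:
            rows.append((int(m.group(1)), float(m.group(2)), float(m.group(3))))
    return rows

rows = parse_log(LOG)
best_epoch, best_acc, _ = max(rows, key=lambda r: r[1])
first_98 = next(epoch for epoch, acc, _ in rows if acc >= 0.98)
print(f"best: epoch {best_epoch}, accuracy {best_acc}")
print(f"first epoch at or above 0.98: {first_98}")
```

The test accuracy first reaches 0.98 at epoch 39, peaks at 0.9811 at epoch 48, and fluctuates between 0.9793 and 0.9811 over these last epochs. Since the same test set is used to monitor every epoch, picking the best epoch this way gives a slightly optimistic estimate.
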

Reposted from blog.csdn.net/star_of_science/article/details/104276526