TensorFlow: Further Optimization to Push MNIST Accuracy Above 98%

Code:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


# Load the MNIST dataset (TF 1.x tutorial API)
# "MNISt_data" is a folder under the current working directory
mnist = input_data.read_data_sets("MNISt_data", one_hot=True)

Output:

Extracting MNISt_data/train-images-idx3-ubyte.gz
Extracting MNISt_data/train-labels-idx1-ubyte.gz
Extracting MNISt_data/t10k-images-idx3-ubyte.gz
Extracting MNISt_data/t10k-labels-idx1-ubyte.gz
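
Because one_hot=True, each label comes back as a 10-dimensional one-hot vector rather than a single digit. A quick illustration (plain NumPy, not from the original post):

import numpy as np
label = np.zeros(10)
label[3] = 1.0     # one-hot encoding of the digit 3
print(label)       # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
# mnist.train.labels accordingly has shape (55000, 10)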

Code:

# Size of each mini-batch
# (each batch is fed in as a single matrix of batch_size rows)
batch_size = 100
# Number of batches per epoch
n_batch = mnist.train.num_examples // batch_size


# Define three placeholders
# each image is 28 x 28 = 784 pixels, flattened into a vector
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
# Learning rate, stored as a variable so it can be decayed between epochs
lr = tf.Variable(0.001, dtype=tf.float32)


# Build the neural network:
# input layer 784, hidden layer 1 with 500 neurons, hidden layer 2 with 300, output layer 10
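# (about 546k trainable parameters: 784*500+500 + 500*300+300 + 300*10+10 = 545,810)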
# Hidden layer 1
W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
b1 = tf.Variable(tf.zeros([500]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)


# Hidden layer 2
W2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1))
b2 = tf.Variable(tf.zeros([300]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)



# Output layer: keep the raw logits separate from the softmax output
W3 = tf.Variable(tf.truncated_normal([300, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
logits = tf.matmul(L2_drop, W3) + b3
prediction = tf.nn.softmax(logits)

# Cross-entropy loss
# (softmax_cross_entropy_with_logits applies softmax internally, so it must be
# fed the raw logits; passing the softmax output through it is a common bug)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Training step: Adam with the adjustable learning rate
train_step = tf.train.AdamOptimizer(lr).minimize(loss)
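# Note: Adam also maintains per-parameter moment estimates internally;
# the lr variable above only scales its base step size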

# Op that initializes all variables
init = tf.global_variables_initializer()



# Store the per-example results in a list of booleans:
# True where tf.argmax(y, 1) and tf.argmax(prediction, 1) agree, False otherwise
# (argmax returns the index of the largest value along a 1-D tensor)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
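# e.g. if one row of prediction peaks at index 1, argmax returns 1, and the
# example counts as correct when the one-hot label also peaks at index 1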

# Compute the accuracy:
# tf.cast(correct_prediction, tf.float32) turns booleans into 0.0/1.0, so the mean is the accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


with tf.Session() as sess:
    sess.run(init)
    # 51 epochs in total
    for epoch in range(51):
        # Start with a relatively large learning rate and decay it each epoch
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))
        # n_batch batches per epoch
        for batch in range(n_batch):
            # Fetch one batch
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.0})

        learning_rate = sess.run(lr)
        # After each epoch, measure accuracy on the full test set
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})

        print("Iter" + str(epoch) + ", Testing Accuracy" + str(acc) + ", Learning_rate" + str(learning_rate))

Output:

Iter0, Testing Accuracy0.9527, Learning_rate0.001
Iter1, Testing Accuracy0.9634, Learning_rate0.00095
Iter2, Testing Accuracy0.9692, Learning_rate0.0009025
Iter3, Testing Accuracy0.9711, Learning_rate0.000857375
Iter4, Testing Accuracy0.9716, Learning_rate0.000814506
Iter5, Testing Accuracy0.9757, Learning_rate0.000773781
Iter6, Testing Accuracy0.9737, Learning_rate0.000735092
Iter7, Testing Accuracy0.9779, Learning_rate0.000698337
Iter8, Testing Accuracy0.9776, Learning_rate0.00066342
Iter9, Testing Accuracy0.9778, Learning_rate0.000630249
Iter10, Testing Accuracy0.9764, Learning_rate0.000598737
Iter11, Testing Accuracy0.9787, Learning_rate0.0005688
Iter12, Testing Accuracy0.9776, Learning_rate0.00054036
Iter13, Testing Accuracy0.9779, Learning_rate0.000513342
Iter14, Testing Accuracy0.9805, Learning_rate0.000487675
Iter15, Testing Accuracy0.9786, Learning_rate0.000463291
Iter16, Testing Accuracy0.9786, Learning_rate0.000440127
Iter17, Testing Accuracy0.9806, Learning_rate0.00041812
Iter18, Testing Accuracy0.9768, Learning_rate0.000397214
Iter19, Testing Accuracy0.9789, Learning_rate0.000377354
Iter20, Testing Accuracy0.9793, Learning_rate0.000358486
Iter21, Testing Accuracy0.9802, Learning_rate0.000340562
Iter22, Testing Accuracy0.9783, Learning_rate0.000323534
Iter23, Testing Accuracy0.9802, Learning_rate0.000307357
Iter24, Testing Accuracy0.9807, Learning_rate0.000291989
Iter25, Testing Accuracy0.9804, Learning_rate0.00027739
Iter26, Testing Accuracy0.9801, Learning_rate0.00026352
Iter27, Testing Accuracy0.9776, Learning_rate0.000250344
Iter28, Testing Accuracy0.9804, Learning_rate0.000237827
Iter29, Testing Accuracy0.9807, Learning_rate0.000225936
Iter30, Testing Accuracy0.9804, Learning_rate0.000214639
Iter31, Testing Accuracy0.9805, Learning_rate0.000203907
Iter32, Testing Accuracy0.9794, Learning_rate0.000193711
Iter33, Testing Accuracy0.9806, Learning_rate0.000184026
Iter34, Testing Accuracy0.979, Learning_rate0.000174825
Iter35, Testing Accuracy0.9811, Learning_rate0.000166083
Iter36, Testing Accuracy0.9807, Learning_rate0.000157779
Iter37, Testing Accuracy0.9809, Learning_rate0.00014989
Iter38, Testing Accuracy0.9812, Learning_rate0.000142396
Iter39, Testing Accuracy0.9806, Learning_rate0.000135276
Iter40, Testing Accuracy0.9809, Learning_rate0.000128512
Iter41, Testing Accuracy0.9812, Learning_rate0.000122087
Iter42, Testing Accuracy0.9812, Learning_rate0.000115982
Iter43, Testing Accuracy0.9814, Learning_rate0.000110183
Iter44, Testing Accuracy0.9801, Learning_rate0.000104674
Iter45, Testing Accuracy0.9812, Learning_rate9.94403e-05
Iter46, Testing Accuracy0.9814, Learning_rate9.44682e-05
Iter47, Testing Accuracy0.9807, Learning_rate8.97448e-05
Iter48, Testing Accuracy0.9817, Learning_rate8.52576e-05
Iter49, Testing Accuracy0.9811, Learning_rate8.09947e-05
Iter50, Testing Accuracy0.9812, Learning_rate7.6945e-05
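
The printed learning rates are simply the exponential-decay schedule lr = 0.001 * 0.95^epoch evaluated outside the graph; a few lines of plain Python reproduce them:

for epoch in range(4):
    print(epoch, 0.001 * 0.95 ** epoch)
# 0 0.001
# 1 0.00095
# 2 0.0009025
# 3 0.000857375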

Reposted from blog.csdn.net/wangsiji_buaa/article/details/80205629