代码:
# Load the MNIST data set; files are downloaded/extracted under ./MNISt_data
# (the tensorflow.examples.tutorials API is TensorFlow 1.x only).
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# one_hot=True: labels arrive as 10-element one-hot vectors
# (e.g. digit 3 -> [0,0,0,1,0,0,0,0,0,0]).
mnist = input_data.read_data_sets("MNISt_data", one_hot=True)
运行结果:
Extracting MNISt_data/train-images-idx3-ubyte.gz Extracting MNISt_data/train-labels-idx1-ubyte.gz Extracting MNISt_data/t10k-images-idx3-ubyte.gz Extracting MNISt_data/t10k-labels-idx1-ubyte.gz
代码:
# MNIST classifier: 784 -> 500 -> 300 -> 10 fully-connected network with
# tanh activations, dropout, Adam, and a per-epoch decayed learning rate.

# Batch size: images are fed to the graph in matrices of this many rows.
batch_size = 100
# Number of whole batches per epoch.
n_batch = mnist.train.num_examples // batch_size

# Placeholders: 28 x 28 images flattened to 784 features, 10-way one-hot
# labels, and the dropout keep probability.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)

# Learning rate as a Variable so it can be reassigned (decayed) each epoch.
lr = tf.Variable(0.001, dtype=tf.float32)

# Hidden layer 1: 784 -> 500.
W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
b1 = tf.Variable(tf.zeros([500]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

# Hidden layer 2: 500 -> 300.
W2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1))
b2 = tf.Variable(tf.zeros([300]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

# Output layer: 300 -> 10.
W3 = tf.Variable(tf.truncated_normal([300, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
# BUG FIX: the original computed prediction = softmax(...) and then passed
# it as the `logits` argument of softmax_cross_entropy_with_logits, which
# applies softmax a second time and flattens the gradients. Keep the raw
# logits for the loss; apply softmax only to produce `prediction`.
# (argmax(softmax(z)) == argmax(z), so the accuracy metric is unaffected.)
logits = tf.matmul(L2_drop, W3) + b3
prediction = tf.nn.softmax(logits)

# Cross-entropy loss on the unscaled logits, averaged over the batch.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# Adam with the (decaying) learning-rate variable.
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

init = tf.global_variables_initializer()

# Boolean vector: True where the predicted class equals the label class.
# tf.argmax returns the index of the largest entry along axis 1.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Accuracy = mean of the boolean vector cast to float32.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    # 51 training epochs.
    for epoch in range(51):
        # Start with a larger learning rate and decay it exponentially.
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # NOTE(review): keep_prob is fed as 1.0, so dropout is
            # effectively a no-op during training — feed e.g. 0.7 here
            # if regularization is actually wanted.
            sess.run(train_step,
                     feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.0})
        learning_rate = sess.run(lr)
        # Evaluate on the full test set after each epoch (keep_prob=1.0:
        # dropout disabled at evaluation time, as it should be).
        acc = sess.run(accuracy,
                       feed_dict={x: mnist.test.images,
                                  y: mnist.test.labels,
                                  keep_prob: 1.0})
        print("Iter" + str(epoch) + ", Testing Accuracy" + str(acc) + ", Learning_rate" + str(learning_rate))
运行结果:
Iter0, Testing Accuracy0.9527, Learning_rate0.001 Iter1, Testing Accuracy0.9634, Learning_rate0.00095 Iter2, Testing Accuracy0.9692, Learning_rate0.0009025 Iter3, Testing Accuracy0.9711, Learning_rate0.000857375 Iter4, Testing Accuracy0.9716, Learning_rate0.000814506 Iter5, Testing Accuracy0.9757, Learning_rate0.000773781 Iter6, Testing Accuracy0.9737, Learning_rate0.000735092 Iter7, Testing Accuracy0.9779, Learning_rate0.000698337 Iter8, Testing Accuracy0.9776, Learning_rate0.00066342 Iter9, Testing Accuracy0.9778, Learning_rate0.000630249 Iter10, Testing Accuracy0.9764, Learning_rate0.000598737 Iter11, Testing Accuracy0.9787, Learning_rate0.0005688 Iter12, Testing Accuracy0.9776, Learning_rate0.00054036 Iter13, Testing Accuracy0.9779, Learning_rate0.000513342 Iter14, Testing Accuracy0.9805, Learning_rate0.000487675 Iter15, Testing Accuracy0.9786, Learning_rate0.000463291 Iter16, Testing Accuracy0.9786, Learning_rate0.000440127 Iter17, Testing Accuracy0.9806, Learning_rate0.00041812 Iter18, Testing Accuracy0.9768, Learning_rate0.000397214 Iter19, Testing Accuracy0.9789, Learning_rate0.000377354 Iter20, Testing Accuracy0.9793, Learning_rate0.000358486 Iter21, Testing Accuracy0.9802, Learning_rate0.000340562 Iter22, Testing Accuracy0.9783, Learning_rate0.000323534 Iter23, Testing Accuracy0.9802, Learning_rate0.000307357 Iter24, Testing Accuracy0.9807, Learning_rate0.000291989 Iter25, Testing Accuracy0.9804, Learning_rate0.00027739 Iter26, Testing Accuracy0.9801, Learning_rate0.00026352 Iter27, Testing Accuracy0.9776, Learning_rate0.000250344 Iter28, Testing Accuracy0.9804, Learning_rate0.000237827 Iter29, Testing Accuracy0.9807, Learning_rate0.000225936 Iter30, Testing Accuracy0.9804, Learning_rate0.000214639 Iter31, Testing Accuracy0.9805, Learning_rate0.000203907 Iter32, Testing Accuracy0.9794, Learning_rate0.000193711 Iter33, Testing Accuracy0.9806, Learning_rate0.000184026 Iter34, Testing Accuracy0.979, Learning_rate0.000174825 Iter35, Testing Accuracy0.9811, 
Learning_rate0.000166083 Iter36, Testing Accuracy0.9807, Learning_rate0.000157779 Iter37, Testing Accuracy0.9809, Learning_rate0.00014989 Iter38, Testing Accuracy0.9812, Learning_rate0.000142396 Iter39, Testing Accuracy0.9806, Learning_rate0.000135276 Iter40, Testing Accuracy0.9809, Learning_rate0.000128512 Iter41, Testing Accuracy0.9812, Learning_rate0.000122087 Iter42, Testing Accuracy0.9812, Learning_rate0.000115982 Iter43, Testing Accuracy0.9814, Learning_rate0.000110183 Iter44, Testing Accuracy0.9801, Learning_rate0.000104674 Iter45, Testing Accuracy0.9812, Learning_rate9.94403e-05 Iter46, Testing Accuracy0.9814, Learning_rate9.44682e-05 Iter47, Testing Accuracy0.9807, Learning_rate8.97448e-05 Iter48, Testing Accuracy0.9817, Learning_rate8.52576e-05 Iter49, Testing Accuracy0.9811, Learning_rate8.09947e-05 Iter50, Testing Accuracy0.9812, Learning_rate7.6945e-05