MNIST是深度学习的入门demo,由6万张训练图片和1万张测试图片构成(数据集下载地址:https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),每张图片都是28*28大小,而且都是黑白两色,这些图片是采集的不同的人手写从0到9的数字。TensorFlow将这个数据集和相关操作封装到了库中,下面为训练及效果评估代码。
import sys from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import numpy as np import matplotlib.pyplot as plt from pylab import mpl #设置plt显示中文字体,避免乱码 mpl.rcParams['font.sans-serif']=['Microsoft YaHei'] mpl.rcParams['axes.unicode_minus'] = False #读取训练集 mnist = input_data.read_data_sets("d:/share/MNIST_data/", one_hot=True) # x是特征值,1X784的一维向量 x = tf.placeholder(tf.float32, [None, 784]) # w表示每一个特征值(像素点)会影响结果的权重 W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) #y是预测值,包含10个元素的数组 y = tf.matmul(x, W) + b # y_是图片实际对应的值,包含10个元素的0/1数组,1代表对应的index数字 y_ = tf.placeholder(tf.float32, [None, 10] ) cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.InteractiveSession() tf.global_variables_initializer().run() # mnist.train 训练数据,每次提取100张,循环6000次 for _ in range(6000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # 取得y的最大概率对应的数组索引来和y_的数组索引对比,如果索引相同,则表示预测正确 correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
用测试集来评估准确性
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
逐个测试图片查看其预测结果,将预测不准确的结果统计输出 pre_act=[] for i in range(0, len(mnist.test.images)): result = sess.run(correct_prediction, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])}) if not result: pre=sess.run(y, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])}) pre_list=max(pre.tolist()) m=max(pre_list) pre_value=pre_list.index(m) # print('预测的值是:', pre_value) actual=sess.run(y_, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])}) actual_list=max(actual.tolist()) actual_value=actual_list.index(1) #将预测值和实际值组合添加到数组保存 pa='预测的值是:'+str(pre_value)+","+'实际的值是:'+str(actual_value) pre_act.append(pa) # print('实际的值是:', actual_value) display='预测的值是:'+str(pre_value)+'实际的值是:'+str(actual_value) #显示预测错误图片 # one_pic_arr = np.reshape(mnist.test.images [i], (28, 28)) # pic_matrix = np.matrix(one_pic_arr, dtype="float") # plt.imshow(pic_matrix) # plt.title(display) # plt.savefig('pic_matrix') # plt.show() # break #打印数组查看哪些测试图片预测错误及其真实值 print(pre_act) print("预测错误数量为:"+str(len(pre_act))+"测试数据集为:"+str(len(mnist.test.images)))
可以看到准确率只有0.9257,使用CNN准确率可以达到0.97以上。