Tensorflow-MNIST代码解析

MNIST是深度学习的入门demo,由6万张训练图片和1万张测试图片构成(数据集下载地址:https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),每张图片都是28*28大小,而且都是黑白两色,这些图片是采集的不同的人手写从0到9的数字。TensorFlow将这个数据集和相关操作封装到了库中,下面为训练及效果评估代码。

import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import matplotlib.pyplot as plt
from pylab import mpl
#设置plt显示中文字体,避免乱码
mpl.rcParams['font.sans-serif']=['Microsoft YaHei']
mpl.rcParams['axes.unicode_minus'] = False
#读取训练集
mnist = input_data.read_data_sets("d:/share/MNIST_data/", one_hot=True)
# x是特征值,1X784的一维向量
x = tf.placeholder(tf.float32, [None, 784])
# w表示每一个特征值(像素点)会影响结果的权重
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
#y是预测值,包含10个元素的数组
y = tf.matmul(x, W) + b
# y_是图片实际对应的值,包含10个元素的0/1数组,1代表对应的index数字
y_ = tf.placeholder(tf.float32, [None, 10] )
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# mnist.train 训练数据,每次提取100张,循环6000次
for _ in range(6000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# 取得y的最大概率对应的数组索引来和y_的数组索引对比,如果索引相同,则表示预测正确
correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

用测试集来评估准确性

 print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                     y_: mnist.test.labels}))
逐个测试图片查看其预测结果,将预测不准确的结果统计输出
pre_act=[]
for i in range(0, len(mnist.test.images)):
    result = sess.run(correct_prediction,
                      feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])})
    if not result:
        pre=sess.run(y, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])})
        pre_list=max(pre.tolist())
        m=max(pre_list)
        pre_value=pre_list.index(m)
        # print('预测的值是:', pre_value)
        actual=sess.run(y_, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])})
        actual_list=max(actual.tolist())
        actual_value=actual_list.index(1)
        #将预测值和实际值组合添加到数组保存
        pa='预测的值是:'+str(pre_value)+","+'实际的值是:'+str(actual_value)
        pre_act.append(pa)
        # print('实际的值是:', actual_value)
        display='预测的值是:'+str(pre_value)+'实际的值是:'+str(actual_value)
        #显示预测错误图片
        # one_pic_arr = np.reshape(mnist.test.images [i], (28, 28))
        # pic_matrix = np.matrix(one_pic_arr, dtype="float")
        # plt.imshow(pic_matrix)
        # plt.title(display)
        # plt.savefig('pic_matrix')
        # plt.show()
        # break
#打印数组查看哪些测试图片预测错误及其真实值
print(pre_act)
print("预测错误数量为:"+str(len(pre_act))+"测试数据集为:"+str(len(mnist.test.images)))

可以看到准确率只有0.9257,使用CNN准确率可以达到0.97以上。

发布了123 篇原创文章 · 获赞 12 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/haiziccc/article/details/102378549