简单的Tensorflow手写识别体

来源于《TensorFlow实战》, 黄文坚, 唐源

# -*- coding:utf-8 -*-
# [https://www.amazon.cn/dp/B06X8Z4BS9/ref=sr_1_1?ie=UTF8&qid=1550477736&sr=8-1&keywords=tensorflow%E5%AE%9E%E6%88%98](TensorFlow实战, 黄文坚, 唐源)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)

import tensorflow as tf
sess = tf.InteractiveSession() # 交互式的session
x = tf.placeholder(tf.float32, [None, 784]) # 输入的特征占位

W = tf.Variable(tf.zeros([784, 10])) # tf变量(参数),可以求梯度和被更新
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b) # matmul矩阵乘法,‘+’中二维数组加上一维数组,执行广播功能,一维数组自动扩展成二维数组,即二维数组每一个行分量加上这个一维数组

y_ = tf.placeholder(tf.float32, [None, 10]) # 输入的真实标签占位
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) # reduce_sum按第二个维度对y_ * tf.log(y)得到的数组求和(因为这里y_是10维one-hot编码的向量,所以10次乘法中9个都是0*tf.log(y_i)),得到每个样本cross_entropy,之后tf.reduce_mean把这些样本的cross_entropy平均,即这里cross_entropy是一个平均

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 参数更新器

tf.global_variables_initializer().run() # 全局变量初始化(参数初始化),session是交互式的,所以直接.run()即可。否则,普通的session需要session.run(...)

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys}) # 不断迭代运行参数更新器,优化参数

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # 真实标签y_和预测输出y中最大值的索引是否一样,True or Fasle

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # tf.cast将 True or Fasle 转换成 1 or 0, 即预测对了和错了,然后求平均

print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})) # eval评价

这里注意该网络为简单网络,参数初始化并不重要,可以全部初始化为0。但是对于复杂的网络,参数初始化非常重要,不能全部初始化为0(参见1参见2),一般初始化为正态分布或均匀分布,如果全部初始化为0,那么预测结果将都是标签数据的均值,训练集MSE下降到标签数据的方差后,无法继续下降,网络失效!

猜你喜欢

转载自blog.csdn.net/W_LAILAI/article/details/87636443
今日推荐