TensorFlow学习(2)简单手写数字识别

MNIST数据集

  • 下载网址:MNIST
  • 下载的数据:训练集的图片和标签,测试集的图片和标签,60000行的训练数据集和10000行的测试数据集

将二维数组展开成一维的向量

构建简单的神经网络

softmax函数

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 载入数据集
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)

# 不是一张张图片放入神经网络,定义一个批次,一次 100
batch_size = 100
# 计算一个有多少批次,整除
n_batch = mnist.train.num_examples // batch_size

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# 创建一个简单的神经网络
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

prediction = tf.nn.softmax(tf.matmul(x,W)+b)

# 二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
# 梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

init = tf.global_variables_initializer()

# 结果存放在布尔型列表中
# tf.equal 相等返回 True,否则 False,argmax 比较 y 中哪个元素的值为 1,返回该元素下标
correct_predition = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# 求准确率
# tf.cast 将布尔型转换为32位浮点型,True -> 1.0,False -> 0.0,然后求平均值,如有 9 个 1,1 个 0,平均值为 0.9,准确率为 0.9
accuracy = tf.reduce_mean(tf.cast(correct_predition, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    # 循环 21 个周期,每个周期批次为 100,每个周期将所有图片都训练一次
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
        #训练完一个周期看下准确率
        acc = sess.run(accuracy, feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print('Iter ' + str(epoch) + ', Testing Accuracy' + str(acc))


Iter 0, Testing Accuracy0.8321
Iter 1, Testing Accuracy0.8712
Iter 2, Testing Accuracy0.8819
Iter 3, Testing Accuracy0.8879
Iter 4, Testing Accuracy0.8941
Iter 5, Testing Accuracy0.8972
Iter 6, Testing Accuracy0.9
Iter 7, Testing Accuracy0.9016
Iter 8, Testing Accuracy0.9043
Iter 9, Testing Accuracy0.905
Iter 10, Testing Accuracy0.9061
Iter 11, Testing Accuracy0.9073
Iter 12, Testing Accuracy0.9087
Iter 13, Testing Accuracy0.9099
Iter 14, Testing Accuracy0.9099
Iter 15, Testing Accuracy0.911
Iter 16, Testing Accuracy0.9117
Iter 17, Testing Accuracy0.9122
Iter 18, Testing Accuracy0.9134
Iter 19, Testing Accuracy0.9143
Iter 20, Testing Accuracy0.9137

优化

  • 修改批次大小 batch_size
  • 修改训练周期次数
  • 增加隐藏层和修改激活函数
  • 修改权重和偏置初始化策略
  • 修改代价函数
  • 修改训练方法和学习率

猜你喜欢

转载自blog.csdn.net/HAIYUANBOY/article/details/89838541
今日推荐