【tensorflow】数字识别— softmax 回归

soft回归,图片来自 tensorflow官网
  提到多分类任务,立马会想到使用 softmax 回归,这篇文章主要讲述在 tensorflow 平台使用 softmax 回归方法对 mnist 数据进行数字识别。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 读取数据
mnist_data = input_data.read_data_sets('../data/mnist', one_hot=True)

X = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='X')
y_ = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='y_')

#  创建 W, b 变量
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# 执行 softmax 函数
y = tf.nn.softmax(tf.add(tf.matmul(X, W), b))

#求损失函数
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

#梯度下降法训练
train_step = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(cross_entropy)

# 正确预测数目
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

# 准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.Session()
sess.run(tf.global_variables_initializer())

for i in range(2000):
    xs, ys = mnist_data.train.next_batch(64)
    sess.run(train_step, feed_dict={X: xs, y_: ys})
    if i % 100 == 0:
        print(i, ' accuracy is: ', sess.run(accuracy, feed_dict={X: xs, y_: ys}))

print(sess.run(accuracy, feed_dict={X: mnist_data.test.images, y_: mnist_data.test.labels}))

  算法预测准确率在92%左右

猜你喜欢

转载自blog.csdn.net/lionel_fengj/article/details/80487482