提到多分类任务，人们通常会立刻想到 softmax 回归。本文介绍如何在 TensorFlow 平台上使用 softmax 回归方法对 MNIST 数据集进行手写数字识别。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data from ../data/mnist. one_hot=True makes every label a
# one-hot vector (matching the 10-wide y_ placeholder below).
mnist_data = input_data.read_data_sets(
    '../data/mnist',
    one_hot=True,
)

# Graph inputs, supplied at run time through feed_dict:
#   X  -- a batch of examples with 784 float features each
#   y_ -- the matching one-hot labels over 10 classes
X = tf.placeholder(tf.float32, (None, 784), 'X')
y_ = tf.placeholder(tf.float32, (None, 10), 'y_')
# Model parameters: a weight matrix mapping the 784 input features to 10
# class scores, plus one bias per class. Zero initialization is fine here
# because the model is a single linear layer followed by softmax.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Unnormalized class scores (logits): X @ W + b.
logits = tf.add(tf.matmul(X, W), b)
# Class probabilities; used for prediction / accuracy below.
y = tf.nn.softmax(logits)

# Mean cross-entropy loss over the batch. softmax_cross_entropy_with_logits
# applies softmax internally, so it must receive the raw logits. Passing the
# already-softmaxed `y` would apply softmax twice, which distorts the loss
# and squashes its gradients.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

# One plain gradient-descent step (learning rate 0.5) on the loss.
train_step = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(cross_entropy)

# Per-example boolean: does the most probable class equal the labelled class?
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# Accuracy: fraction of correct predictions in whatever batch is fed.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Create a session and initialize all global variables (the model's W and b).
# The session is intentionally left open so `sess` stays usable afterwards;
# call sess.close() when finished with it.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Training hyperparameters.
NUM_STEPS = 2000
BATCH_SIZE = 64
LOG_EVERY = 100

for i in range(NUM_STEPS):
    # Draw the next minibatch of training examples and one-hot labels.
    xs, ys = mnist_data.train.next_batch(BATCH_SIZE)
    sess.run(train_step, feed_dict={X: xs, y_: ys})
    if i % LOG_EVERY == 0:
        # Accuracy on the minibatch just trained on: a noisy training
        # metric, not an estimate of held-out performance.
        print(i, ' accuracy is: ', sess.run(accuracy, feed_dict={X: xs, y_: ys}))

# Final evaluation on the full test split.
print(sess.run(accuracy, feed_dict={X: mnist_data.test.images, y_: mnist_data.test.labels}))
该算法在测试集上的预测准确率约为 92%。