TensorFlow学习记录2——基于softmax回归的分类算法

TensorFlow学习记录2——基于softmax回归的分类算法


主要参考博客
LightRNN:深度学习之以小见大

深入MNIST

import tensorflow as tf
import numpy as np
import matplotlib.pylab as plt

# prepare mnist data
from tensorflow.examples.tutorials.mnist import input_data
# Raw string literal: the Windows path contains backslash sequences such as
# "\p" and "\O" that are not valid escapes; in a normal literal they trigger
# a warning and are slated to become a syntax error. The value is unchanged.
MNIST_data_folder = r"D:\pycharm\OCR_Test_SH\src\LightRNN\MNIST_data"
# Load the MNIST splits from that folder with one-hot encoded labels.
mnist = input_data.read_data_sets(MNIST_data_folder, one_hot=True)

# Take one training sample and fold its flat pixel vector into rows of 28
# values so it can be previewed with the (disabled) imshow call below.
im = mnist.train.images[1]
im = im.reshape(-1, 28)
print('input:', mnist.train.images.shape)
# plt.imshow(im)
# plt.show()


# An InteractiveSession installs itself as the default session, which is what
# lets the later `train_step.run(...)` / `accuracy.eval(...)` calls work
# without an explicit session argument.
sess = tf.InteractiveSession()

# Inputs: a batch of flattened images (784 values each) and the matching
# one-hot labels (10 classes). The batch dimension is left open.
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

# Softmax-regression parameters, both starting at zero.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Model output: class probabilities per example.
y = tf.nn.softmax(tf.matmul(x, W) + b)
# Cross-entropy summed (not averaged) over the whole batch.
# NOTE(review): tf.log(y) is -inf if a softmax output underflows to 0, which
# would turn the loss into NaN; consider a logits-based loss if that shows up.
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# Per-example correctness (predicted class == labelled class) and its mean.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# Initialise variables only after the whole graph (including the optimizer)
# has been built, so every variable that exists is covered.
# tf.global_variables_initializer() replaces the deprecated
# tf.initialize_all_variables().
sess.run(tf.global_variables_initializer())

# Run 1000 gradient-descent steps on mini-batches of 50 samples. After each
# step, print the step index and the accuracy measured on that same batch.
for i in range(1000):
    images, labels = mnist.train.next_batch(50)
    feed = {x: images, y_: labels}
    train_step.run(feed_dict=feed)
    print(i, accuracy.eval(feed_dict=feed))

# Final accuracy over the complete test split.
test_feed = {x: mnist.test.images, y_: mnist.test.labels}
print(accuracy.eval(feed_dict=test_feed))
# Show every misclassified test image together with its true and predicted
# digit.
#
# Fixes relative to the previous version:
#   * The model output `y` was reported as the "label" and the fed
#     ground truth `y_` as the "predict"; they are now reported correctly.
#   * `tf.argmax` was called inside the loop, which added new ops to the graph
#     on every iteration and printed Tensor objects instead of digits; the
#     argmax is now taken with NumPy on concrete arrays.
#   * The test set is pushed through the model once instead of running the
#     session up to three times per image.
test_images = mnist.test.images
true_digits = np.argmax(mnist.test.labels, axis=1)
predicted_digits = np.argmax(sess.run(y, feed_dict={x: test_images}), axis=1)

for i in np.flatnonzero(predicted_digits != true_digits):
    print('the predict result of {} is error, the label is {} and the predict is {}'.format(
        i, true_digits[i], predicted_digits[i]))
    # Restore the 28x28 layout of the flat pixel vector for display.
    current_image = np.reshape(test_images[i], (28, 28)).astype("float")
    plt.imshow(current_image)
    # Blocks until the figure window is closed, then moves to the next error.
    plt.show()
    # break

待做:代码46行

猜你喜欢

转载自blog.csdn.net/JavenLau/article/details/90673061