版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/slz0813/article/details/78832908
TensorFlow下运用自带的input_data接口读取数据,抽取单个样本,将1*784 reshape为28*28,最后可视化显示
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
# Train the network, report test-set accuracy, then classify a single MNIST
# test image and display it.
# NOTE(review): init, n_batch, batch_size, mnist, x, y, train_step, accuracy
# and prediction are not defined in this excerpt; presumably they are built
# earlier in the file (e.g. mnist = input_data.read_data_sets(...)) — confirm.
with tf.Session() as sess:
    # Presumably init is the variable-initializer op — confirm.
    sess.run(init)
    for epoch in range(1):
        # One pass over the training split in mini-batches of batch_size.
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        # Evaluate on the full test split after each epoch.
        acc = sess.run(accuracy,
                       feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))

    # Forward detection: draw one sample from the held-out test split so the
    # check is not run on an image the model was trained on.
    test_x, test_y = mnist.test.next_batch(1)
    print("test_x_shape=", test_x.shape, ",test_y_shape=", test_y.shape)
    result_value = sess.run(prediction, feed_dict={x: test_x})

    # Take the argmax over axis 1 in NumPy: both values are already NumPy
    # arrays, and creating tf.argmax ops here would keep growing the graph.
    # Assumes labels are one-hot (read_data_sets(one_hot=True)) — confirm.
    dec = result_value.argmax(axis=1)
    truth = test_y.argmax(axis=1)
    # Print before plt.show(), which blocks until the window is closed.
    print("dec=", dec, ",origin=", truth)

    # The sample is a flattened 1x784 row; reshape to 28x28 for display.
    img = test_x.reshape(28, 28)
    fig = plt.figure()
    plotwindow = fig.add_subplot(111)
    plotwindow.imshow(img, cmap='gray')
    plt.show()