TensorFlow(八) TensorFlow图像识别(KNN)

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn import  datasets
import random
from PIL import Image

from tensorflow.examples.tutorials.mnist import  input_data

sess = tf.Session()
# Labels are one-hot encoded (one_hot=True); this example has 10 classes (digits 0-9).
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
train_size = 1000
test_size = 102

# The training subset comes from the MNIST training split and the test subset
# from the separate MNIST test split, so the two can never share an image.
# (Sampling both from mnist.train would let a test image find itself as its own
# nearest neighbour at distance 0 and inflate the measured accuracy.)
rand_train_indices = np.random.choice(len(mnist.train.images), train_size, replace=False)
rand_test_indices = np.random.choice(len(mnist.test.images), test_size, replace=False)

x_vals_train = mnist.train.images[rand_train_indices]
x_vals_test = mnist.test.images[rand_test_indices]
y_vals_train = mnist.train.labels[rand_train_indices]
y_vals_test = mnist.test.labels[rand_test_indices]

k = 4  # number of nearest neighbours that vote on each prediction
batch_size = 6  # number of test images classified per sess.run call

# Inputs: flattened 28x28 images (784 features) and one-hot labels over 10 classes.
x_data_train = tf.placeholder(shape=[None, 784], dtype=tf.float32)
x_data_test = tf.placeholder(shape=[None, 784], dtype=tf.float32)
y_target_train = tf.placeholder(shape=[None, 10], dtype=tf.float32)
# NOTE: no op below consumes y_target_test; feeding it is harmless.
y_target_test = tf.placeholder(shape=[None, 10], dtype=tf.float32)

# L1 (Manhattan) distance from every test image to every training image.
# With B = test batch size and N = number of training images:
# (N, 784) - (B, 1, 784) broadcasts to (B, N, 784); summing axis 2 gives (B, N).
distance = tf.reduce_sum(
    tf.abs(tf.subtract(x_data_train, tf.expand_dims(x_data_test, 1))),
    axis=2)

# top_k selects the largest values, so negate the distances to obtain the
# indices of the k nearest training images: (B, k).
top_k_xvals, top_k_indices = tf.nn.top_k(tf.negative(distance), k=k)
# One-hot labels of those neighbours: gather((N, 10), (B, k)) -> (B, k, 10).
prediction_indices = tf.gather(y_target_train, top_k_indices)
# Vote count per class for each test image: (B, 10).
count_of_prediction = tf.reduce_sum(prediction_indices, axis=1)
# Predicted class = the class with the most votes: (B,).
prediction = tf.argmax(count_of_prediction, axis=1)

# Classify the test set in mini-batches of `batch_size` to keep the
# (batch, train_size, 784) intermediate distance tensor small.
# Integer ceiling division counts a final partial batch under both Python 2
# and 3 (len / batch_size floors under Python 2 and would drop that batch).
num_loop = (len(x_vals_test) + batch_size - 1) // batch_size
test_output = []
actual_vals = []
# The training data is identical for every batch, so build that part of the
# feed once, outside the loop.
train_feed = {x_data_train: x_vals_train, y_target_train: y_vals_train}
for i in range(num_loop):
    min_index = i * batch_size
    max_index = min((i + 1) * batch_size, len(x_vals_test))
    # Current batch of test images and their one-hot labels.
    x_batch = x_vals_test[min_index:max_index]
    y_batch = y_vals_test[min_index:max_index]
    feed = dict(train_feed)
    feed[x_data_test] = x_batch
    predictions = sess.run(prediction, feed_dict=feed)
    test_output.extend(predictions)
    # Convert one-hot labels back to class indices for comparison.
    actual_vals.extend(np.argmax(y_batch, axis=1))

# Accuracy: fraction of test images whose predicted class equals the true class.
# Dividing once by the number of predictions actually collected (instead of
# summing 1/test_size per hit) avoids float drift and cannot index past the
# end of the result lists if they are shorter than test_size.
accuracy = float(np.mean(np.equal(test_output, actual_vals))) if test_output else 0.0
print("Accuracy: " + str(accuracy))

# Display images from the last evaluated batch in a 2x3 grid, each titled with
# its true and predicted class.
actuals = np.argmax(y_batch, axis=1)
n_rows, n_cols = 2, 3
# plt.subplot rejects indices beyond n_rows * n_cols, so cap the number shown;
# otherwise any batch_size > 6 would raise ValueError here.
n_show = min(len(actuals), n_rows * n_cols)
for i, (image, actual, predicted) in enumerate(
        zip(x_batch[:n_show], actuals[:n_show], predictions[:n_show])):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(np.reshape(image, [28, 28]), cmap="Greys_r")
    plt.title('Actual: ' + str(actual) + ' Pred:' + str(predicted), fontsize=10)
    ax = plt.gca()
    # Hide tick marks/labels; they carry no information for an image.
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

plt.show()

猜你喜欢

转载自www.cnblogs.com/x0216u/p/9241759.html