# L1 distance: for each test sample, find the closest sample in the training set and compare their labels.
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Nearest-neighbour classification of MNIST digits using the L1 distance.
# Read the MNIST data sets from /tmp/data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# 5000 training samples act as the reference set; 200 test samples are classified.
Xtrain, Ytrain = mnist.train.next_batch(5000)
Xtest, Ytest = mnist.test.next_batch(200)

# Graph inputs: the whole reference set (one row per image) and one test image.
xtrain = tf.placeholder('float', [None, 784])
xtest = tf.placeholder('float', [784])

# L1 distance between the test image and every reference image.
# xtest ([784]) broadcasts against xtrain ([None, 784]), so the difference is
# rank 2; summing over axis 1 yields one distance per reference image.
# (Axis 2 does not exist for a rank-2 tensor.)
distance = tf.reduce_sum(tf.abs(tf.add(xtrain, tf.negative(xtest))), axis=1)
# Index of the reference image with the smallest distance.
predict = tf.argmin(distance, 0)

accuracy = 0.
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    correct = 0
    for i, (sample, label) in enumerate(zip(Xtest, Ytest)):
        nn_index = sess.run(predict, feed_dict={xtrain: Xtrain, xtest: sample})
        # Labels are one-hot vectors, so argmax recovers the class index.
        predicted_class = np.argmax(Ytrain[nn_index])
        true_class = np.argmax(label)
        print('第', i, '次预测', predicted_class, '真正的类别:', true_class)
        if predicted_class == true_class:
            correct += 1
    # Divide once at the end instead of accumulating 1/len per hit, which
    # avoids floating-point drift in the reported value.
    if len(Xtest):
        accuracy = correct / float(len(Xtest))
print('accuracy:', accuracy)