从零开始学 TensorFlow：最近邻（NN）分类

使用 L1 距离，在训练集中找出离测试样本最近的数据，并比较它们的标签是否一致。

from __future__ import print_function
import numpy as np
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
# L1-distance nearest-neighbour classifier on MNIST (TensorFlow 1.x graph API).
# For each test image, find the training image with the smallest L1 distance
# and use its label as the prediction; report overall accuracy at the end.

# Load MNIST from /tmp/data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Reference set (5000 training images) and query set (200 test images).
Xtrain, Ytrain = mnist.train.next_batch(5000)
Xtest, Ytest = mnist.test.next_batch(200)

# Graph inputs: all training images at once, and one flattened 784-pixel test image.
xtrain = tf.placeholder('float', [None, 784])
xtest = tf.placeholder('float', [784])
# L1 distance from the test image to every training image.
# xtrain - xtest broadcasts to shape [num_train, 784]; summing over axis 1
# (the pixel axis) yields one distance per training image. (The tensor is
# rank 2, so reducing over axis 2 would fail.)
distance = tf.reduce_sum(tf.abs(tf.subtract(xtrain, xtest)), axis=1)
# Index of the closest training image.
predict = tf.argmin(distance, 0)

# Count hits and divide once at the end, rather than repeatedly adding a
# floating-point fraction.
correct = 0
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for i, (x_query, y_query) in enumerate(zip(Xtest, Ytest)):
        nn_index = sess.run(predict, feed_dict={xtrain: Xtrain, xtest: x_query})
        # Labels are one-hot, so argmax recovers the class id.
        predicted_label = np.argmax(Ytrain[nn_index])
        true_label = np.argmax(y_query)
        print('第', i, '次预测', predicted_label, '真正的类别:', true_label)
        if predicted_label == true_label:
            correct += 1
    # float() keeps this a true division under Python 2 as well
    # (the file imports print_function from __future__ for py2 compatibility).
    accuracy = float(correct) / len(Xtest)
    print('accuracy:', accuracy)

猜你喜欢

转载自blog.csdn.net/Neekity/article/details/85223457
今日推荐