KNN: MNIST digit recognition

Copyright notice: please credit the source when reposting. Thank you! https://blog.csdn.net/CQDIY/article/details/88116681

Environment:
Windows 10
Python 3.6
TensorFlow 1.12

import numpy as np
import tensorflow as tf

# Load the MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/mnist/", one_hot=True)

# Limit the data set size: 5000 training samples and 200 test samples
X_train, Y_train = mnist.train.next_batch(5000)
X_test, Y_test = mnist.test.next_batch(200)

# Placeholders: the whole training set and a single test image (784 = 28 x 28 pixels)
x_train = tf.placeholder("float", [None, 784])
x_test = tf.placeholder("float", [784])

# L1 (Manhattan) distance between the test image and every training image
distance = tf.reduce_sum(tf.abs(x_train - x_test), axis=1)

# Prediction: index of the nearest neighbor (the smallest distance)
pred = tf.argmin(distance, 0)

accuracy = 0.0

# Initialize variables
init = tf.global_variables_initializer()
# Run the evaluation (nearest-neighbor classification has no training step)
with tf.Session() as sess:
    sess.run(init)
    
    # Loop over all test samples
    for i in range(len(X_test)):
        nn_index = sess.run(pred, feed_dict={x_train: X_train, x_test: X_test[i, :]})
        print("Test", i, "Prediction:", np.argmax(Y_train[nn_index]),"True Class:", np.argmax(Y_test[i]))
        if np.argmax(Y_train[nn_index]) == np.argmax(Y_test[i]):
            accuracy += 1./len(X_test)
    print("Done!")
    print("Accuracy:", accuracy)
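
The script above classifies each test image by its single nearest training image, so it is KNN with k = 1. As a rough illustration of the same rule outside a TensorFlow session, here is a dependency-free sketch in plain Python. The names `l1_distance` and `knn_predict`, the `k` parameter, and the tiny data set are assumptions introduced here and are not part of the original script. The majority vote for k > 1 is an extension, and labels are assumed to be class ids (one-hot labels would first be converted with `np.argmax`, as the script does when printing).

```python
from collections import Counter


def l1_distance(a, b):
    """L1 (Manhattan) distance: sum of absolute coordinate differences."""
    return sum(abs(p - q) for p, q in zip(a, b))


def knn_predict(train_samples, train_labels, x, k=1):
    """Predict the class of x by majority vote among its k nearest training samples.

    With k=1 this is the nearest-neighbor rule used by the script above.
    """
    # Distance from x to every training sample
    distances = [l1_distance(row, x) for row in train_samples]
    # Indices of the k smallest distances (sorted() is stable, so ties keep index order)
    nearest = sorted(range(len(distances)), key=distances.__getitem__)[:k]
    # Count the class ids of those neighbors and return the most common one
    votes = Counter(train_labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Hypothetical toy data (not MNIST): two features, two classes
train_samples = [[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [0.9, 0.8]]
train_labels = [0, 0, 1, 1]

print(knn_predict(train_samples, train_labels, [0.2, 0.1], k=1))  # nearest sample is index 1, class 0
print(knn_predict(train_samples, train_labels, [0.8, 0.9], k=3))  # neighbors vote 1, 1, 0, so class 1
```

On the real data this loop would be slow; the vectorized `tf.reduce_sum(tf.abs(...))` expression in the script does the same distance computation for all 5000 training images at once.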

Test results:
Extracting /mnist/train-images-idx3-ubyte.gz
Extracting /mnist/train-labels-idx1-ubyte.gz
Extracting /mnist/t10k-images-idx3-ubyte.gz
Extracting /mnist/t10k-labels-idx1-ubyte.gz
Test 0 Prediction: 6 True Class: 6
Test 1 Prediction: 1 True Class: 1
Test 2 Prediction: 3 True Class: 3
Test 3 Prediction: 8 True Class: 8
Test 4 Prediction: 0 True Class: 0
Test 5 Prediction: 2 True Class: 2
Test 6 Prediction: 5 True Class: 5
Test 7 Prediction: 1 True Class: 1
Test 8 Prediction: 0 True Class: 0
Test 9 Prediction: 9 True Class: 4
Test 10 Prediction: 1 True Class: 1

Test 199 Prediction: 0 True Class: 0
Done!
Accuracy: 0.9250000000000007
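
The trailing digits in 0.9250000000000007 are floating-point rounding error: the script adds 1./len(X_test) once per correct prediction, and the small errors accumulate. With 200 test samples, an accuracy of 0.925 corresponds to 185 correct predictions. One way to avoid the drift is to count matches as integers and divide once at the end. A minimal sketch, where the function name `accuracy_from_predictions` is hypothetical and the two lists are copied from the first 11 rows of the output above:

```python
def accuracy_from_predictions(predicted, actual):
    """Fraction of positions where the predicted class equals the true class."""
    if len(predicted) != len(actual) or not actual:
        raise ValueError("predicted and actual must be non-empty and of equal length")
    # Count matches as an integer, then divide exactly once
    correct = sum(1 for p, t in zip(predicted, actual) if p == t)
    return correct / len(actual)


# Tests 0 to 10 from the printed results (only test 9 is misclassified)
predicted = [6, 1, 3, 8, 0, 2, 5, 1, 0, 9, 1]
actual = [6, 1, 3, 8, 0, 2, 5, 1, 0, 4, 1]

print(accuracy_from_predictions(predicted, actual))  # 10 of 11 correct
print(185 / 200)  # 0.925, a single division with no accumulated error
```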
