CS231n(一)Image Classification Notes

图像分类问题:就是已有固定的分类标签集合,然后对于输入的图像,从分类的标签集合中找到一个分类标签,最后把分类标签分配给该输入图像。也就是给定一个图像,预测它属于的那个分类标签。

Nearest Neighbir 分类器

需衡量样本间的距离,就是将两张图片先转化为两个向量I1和I2,然后计算他们的距离(针对所有的像素)。

用L1距离作为例子:

L1距离(曼哈顿距离):

L2距离(欧式距离):

k-Nearest Neighbor分类器

找最相似的k个图片的标签,然后让他们针对测试图片进行投票,最后把票数最高的标签作为对测试图片的预测。所以当k=1的时候,k-Nearest Neighbor分类器就是Nearest Neighbor分类器。

N折交叉验证:将训练集平均分成5份,其中4份用来训练,1份用来验证。然后我们循环着取其中4份来训练,其中1份来验证,最后取所有5次验证结果的平均值作为算法验证结果。

NN分类器的优点:易于理解实现简单。

                   缺点:要记录全部的训练数据,算法测试速度超级慢,过于依赖特征(图像特征不好找)因而准确率不高。

Tensorflow实现KNN算法:(k=1)

这里10000条数据的784个像素分别减去200条测试样本的每一条(一次循环导入一条)

[[1,...,784],

[1,...,784],

.

.

.

[1,...,784]]

import numpy as np
import tensorflow as tf

#load data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

#选择1000条候选样本,200条测试样本,测试样本跟候选样本比较得到最近的K个样本,则k个样本的标签大多数为某一类,测试样本就为哪一类。

Xtr, Ytr = mnist.train.next_batch(10000)
Xte, Yte = mnist.train.next_batch(200)

#TF Graph Input 占位符 用来feed数据
xtr = tf.placeholder(tf.float32, [None,784])
xte = tf.placeholder(tf.float32, [784])

#最近邻使用L1距离(曼哈顿距离)
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)

#预测
pred = tf.arg_min(distance, 0)

accuracy = 0

#初始化
sess = tf.Session()
                               
sess.run(tf.global_variables_initializer())

#
for i in range(len(Xte)):
    # Get nearest neighbor
    nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})  #每次循环feed数据,候选Xtr全部,测试集Xte一次循环输入一条
    # 获得与测试样本最近样本的类别,计算与真实类别的误差
    if i % 10 ==0:
        print("Test", i, "Prediction:", np.argmax(Ytr[nn_index]), "True Class:", np.argmax(Yte[i]))
    # 计算误差率
    if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
        accuracy += 1. / len(Xte)
print('Done!')
print("Accuracy:", accuracy)

结果:

Test 0 Prediction: 8 True Class: 8

Test 10 Prediction: 0 True Class: 0

Test 20 Prediction: 4 True Class: 4

Test 30 Prediction: 3 True Class: 3

Test 40 Prediction: 2 True Class: 2

Test 50 Prediction: 4 True Class: 4

Test 60 Prediction: 3 True Class: 3

Test 70 Prediction: 8 True Class: 8

Test 80 Prediction: 1 True Class: 1

Test 90 Prediction: 5 True Class: 5

Test 100 Prediction: 6 True Class: 6

Test 110 Prediction: 2 True Class: 2

Test 120 Prediction: 3 True Class: 3

Test 130 Prediction: 5 True Class: 5

Test 140 Prediction: 8 True Class: 8

Test 150 Prediction: 8 True Class: 8

Test 160 Prediction: 4 True Class: 4

Test 170 Prediction: 2 True Class: 2

Test 180 Prediction: 6 True Class: 6

Test 190 Prediction: 8 True Class: 8

Done!

Accuracy: 0.9400000000000007

猜你喜欢

转载自blog.csdn.net/qq_40755643/article/details/82831612