机器学习方法原理及编程实现--01.K近邻法(实现MNIST数据)

文章列表
1.机器学习方法原理及编程实现–01.K近邻法(实现MNIST数据).
2.机器学习方法原理及编程实现–01. .
3.机器学习方法原理及编程实现–01. .
4.机器学习方法原理及编程实现–01. .
5.机器学习方法原理及编程实现–01..
6.机器学习方法原理及编程实现–01..

1 KNN

K近邻法(k-nearest neighbor,kNN)是一种基本的分类与回归方法。K近邻法的输入为实例的特征向量,对应于特征空间的点;输出为实例的类别,可以取多类。K近邻法没有显示的学习过程,在对测试数据进行分类时需根据距离度量对所有训练数据计算出k个最邻近的实例,再通过分类决策(如多数表决等)输出类别。3个加粗带下划线的部分为k近邻的3个基本要素。
Knn实际上是一种特征空间划分问题,但对不同的测试实例,特征空间是不同的。在特征空间中,对于每个测试实例,距离该点比其他点更近的所有点组成的区域叫做单元(cell),每个训练实例拥有一个单元,所有训练实例点的单元构成对特征空间的一个划分。

1.1 距离度量

特征空间中两个实例点的距离是两个实例点相似程度的反映。K近邻模型的特征空间一般是n维实数向量空间,常用 范数来衡量:

这里写图片描述

其中p≥1,当p=1时,称为曼哈顿距离;当p=2时,称为欧式距离;

1.2 K值的选择

K值较小时,输出结果会对邻近实例点比较敏感,如果实例点恰巧是噪声,输出会出错;k较大时,距离较远的点也会对输出结果产生影响。一般而言,k越大模型越简单,如果k为所有训练实例的个数,那么模型将直接选取训练数据中类别最多的类作为输出。

1.3 分类决策规则

多数表决用的比较多,相应的无分类率为:

这里写图片描述

式中的I为指示函数,及yi=ci时为1,否则为0。所有多数表决规则的误分类最小等价于经验风险最小化。k一般选取一个较小的值,通常采用交叉验证选择最优的k。此外还需注意的是:对用不同取值范围的特征值是,常常采用数值归一化的方法进行预处理

KNN优缺点:

优点:便于理解
缺点:

  • KNN算法效率不高,每次对测试向量进行分类时,需要保存所有训练数据,而且还需要计算测试向量与训练向量之间的间距。
  • KNN并没有得到任意数据的基础结构信息,无法从中知晓平均实例样本与典型实例样本的特征。
  • 训练数据越多,KNN测试准确率越高,但计算时间也急速增加。

最简单的K近邻法是线性扫描,也就是上面实现的方法,需要计算输入实例与所有训练数据的距离,当训练数据量较大时特别费时。为了提高搜索效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数,如kd树。

相应代码实现如下:

import sys;
import numpy as np
sys.path.append("./../Basic/")
import LoadData as ld

def KNN(inX, dataSet, labels, k, iNormNum):
    subtractMat = np.ones([dataSet.shape[0], 1])*np.array(inX).reshape([1, inX.shape[1]]) - dataSet
    distances = ((subtractMat**iNormNum).sum(axis=1))**(1.0/iNormNum)
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.items(), key=lambda s: s[1], reverse=True)
    return sortedClassCount[0][0]

if __name__ == '__main__':
    iTrainNum      = 60000
    iTestNum       = 10000
    trainDataSet   = ld.getImageDataSet('./../MNISTDat/train-images-idx3-ubyte', iTrainNum, 1, 784)
    trainLabels    = ld.getLabelDataSet('./../MNISTDat/train-labels-idx1-ubyte', iTrainNum)
    testDataSet    = ld.getImageDataSet('./../MNISTDat/t10k-images-idx3-ubyte', iTestNum, 1, 784)
    testLabels     = ld.getLabelDataSet('./../MNISTDat/t10k-labels-idx1-ubyte', iTestNum)
    iErrorNum      = 0
    for iTestInd in range(iTestNum):
        KNNResult = KNN(testDataSet[iTestInd].reshape([1,784]), trainDataSet, trainLabels, 5, 2)
        if (KNNResult != testLabels[iTestInd]): iErrorNum += 1.0
        print("process:%d/%d_totalErrorNum:%d predict_label: %d, real_label: %d" % (iTestInd, iTestNum, iErrorNum, KNNResult, testLabels[iTestInd]))
    print("\ntotal_error_number: %d" % iErrorNum)
    print("\nerror_rate: %f" % (iErrorNum/float(iTestNum)))

文件路径:https://pan.baidu.com/s/1OUB90duLVdRwyS_ZOF4myQ
MNIST数据分类的准确率为96.94%
这里写图片描述

猜你喜欢

转载自blog.csdn.net/drilistbox/article/details/79726654
今日推荐