KNN算法的重点以及实现

废话流

这学期选了岳晓冬老师的机器学习基础,这个老师很吊不多说,我很菜这个也不多说。

笨鸟先飞早入林,那就努力学习!

秋季学习选修了python数据分析,算是有一点基础,秋季的时候有写过knn、pca降维以及kmenas聚类的python实现,但仍有不少不理解的地方,借着这次机会一并干掉。

知识点整理

knn是一种基本的分类与回归方法。

knn没有显式的学习过程。

给定一个训练集,其中实例的类别已定。求出待预测实例距离训练集中实例的距离,选择距离最近的k个实例,通过多数表决的方式进行分类预测。

knn有三个基本的要素:1.k值的选择(这个有点重要!!)  2.距离的度量  3.分类决策规则

 k值的选择

1.选择过小:容易发生过拟合,模型过于复杂。预测结果对近邻实例点非常敏感。邻近节点变化就会引起预测结果的变化。

2.选择过大:容易发生欠拟合,模型过于简单。与输入实例距离较远的点也会有影响,使预测发生错误。

代码实现(python)

基本每行代码都有注释,简单易懂。今天先总结到这里,以后有新的领悟再来总结。

 1 import numpy as np
 2 import operator
 3 
 4 #test_data是测试数据集,train_dataset训练数据集,train_label是标签
 5 def knn_classify(test_data, train_dataset, train_label, k):
 6     # 获得训练样本个数
 7     train_dataset_amount = train_dataset.shape[0]
 8     # 生成矩阵test_rep_mat,这个矩阵的行数和train_dataset一致但只有一列,元素为test_data
 9     test_rep_mat = np.tile(test_data, (train_dataset_amount, 1))
10 
11     # 求差,将平方后的数据相加,sum(axis=1)是将一个矩阵的每一行向量内的数据相加,得到一个list,list的元素个数和行数一样;
12     # 开平方,得到欧式距离
13     distance = (np.sum((test_rep_mat - train_dataset) ** 2, axis=1)) ** 0.5
14 
15     # argsort 将元素从小到大排列,得到这个数组元素在distance中的index(索引),dist_index元素内容是distance的索引
16     dist_index = distance.argsort()
17     # 新建一个字典
18     class_count = {}
19     for i in range(k):
20         # 找距离最近的三个点的标签
21         label = train_label[dist_index[i]]
22         # 如果属于某个类,在该类的基础上加1,相当于增加其权重,
23         # 如果不是某个类则新建字典的一个key并且等于1(本来是为0的,后面加了个1)
24         class_count[label] = class_count.get(label, 0) + 1
25     # 降序排列,item是将字典中每对key和value组成一个元组
26     # operator.itemgetter获取key对象第一个域的值(从0开始计数,这个题目意思就是以字典中第二个数为关键值来比较)
27     class_count_list = sorted(class_count.items(), key=operator.itemgetter(0), reverse=True)
28     #返回结果
29     return class_count_list[0][0]
30 
31 
32 # 写一个主函数来测试下
33 if __name__ == '__main__':
34     train_data_set = np.array([[2.2, 1.4], \
35                             [2.4, 2.3], \
36                             [1.1, 3.4], \
37                             [8.3, 7.3], \
38                             [9.2, 8.3], \
39                             [10.2, 11.1], \
40                             [11.2, 9.3]])
41     train_label = ['A', 'A', 'A', 'B', 'B', 'B', 'B']
42     test_data = [4.6, 3.4]
43     print('分类结果为:', knn_classify(test_data, train_data_set, train_label, 3))

结果展示

猜你喜欢

转载自www.cnblogs.com/jlthzy/p/11944908.html
今日推荐