K-近邻算法入门程序

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/apollo_miracle/article/details/89188646

k-近邻算法(kNN),它的工作原理是:存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。 最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。 

这里首先给出k-近邻算法的伪代码和实际的Python代码,然后详细地解释每行代码的含义。

该函数的功能是使用k-近邻算法将每组数据划分到某个类中,其伪代码如下:

对未知类别属性的数据集中的每个点依次执行以下操作:

(1) 计算已知类别数据集中的点与当前点之间的距离;

(2) 按照距离递增次序排序;

(3) 选取与当前点距离最小的k个点;

(4) 确定前k个点所在类别的出现频率;

(5) 返回前k个点出现频率最高的类别作为当前点的预测分类。

python代码:

from numpy import *
import operator


def create_data_set():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ["A", "A", "B", "B"]
    return group, labels


def kNN(inX, data_set, labels, k):
    # 动态计算数据集的行数
    data_set_size = data_set.shape[0]
    print("data_set_size:", data_set_size)
    # tile(inX, (data_set_size, 1)) #在行方向上重复inX data_set_size次,列1次
    # diff_mat为矩阵相减结果
    diff_mat = tile(inX, (data_set_size, 1)) - data_set
    print("diff_mat:", diff_mat)
    # diff_mat矩阵内每一项求二次方
    sqdiff_mat = diff_mat ** 2
    print("sqdiff_mat:", sqdiff_mat)
    # 矩阵每一行求和
    distances_sq = sqdiff_mat.sum(axis=1)
    print("distances_sq:", distances_sq)
    # 开方
    distances = distances_sq ** 0.5
    print("distances:", distances)
    # 按值的大小对索引升序排列
    sorted_dist_index = distances.argsort()
    print("sorted_dist_index:", sorted_dist_index)
    class_count = {}
    for i in range(k):
        vote_label = labels[sorted_dist_index[i]]
        print("vote_label:", vote_label)
        class_count[vote_label] = class_count.get(vote_label, 0) + 1
        print("class_count[vote_label]:", class_count[vote_label])
    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    print("sorted_class_count:", sorted_class_count)
    return sorted_class_count[0][0]


if __name__ == '__main__':
    group, labels = create_data_set()
    kNN([0, 0], group, labels, 3)

代码结果:

data_set_size: 4
diff_mat: [[-1.  -1.1]
 [-1.  -1. ]
 [ 0.   0. ]
 [ 0.  -0.1]]
sqdiff_mat: [[1.   1.21]
 [1.   1.  ]
 [0.   0.  ]
 [0.   0.01]]
distances_sq: [2.21 2.   0.   0.01]
distances: [1.48660687 1.41421356 0.         0.1       ]
sorted_dist_index: [2 3 1 0]
vote_label: B
class_count[vote_label]: 1
vote_label: B
class_count[vote_label]: 2
vote_label: A
class_count[vote_label]: 1
sorted_class_count: [('B', 2), ('A', 1)]

猜你喜欢

转载自blog.csdn.net/apollo_miracle/article/details/89188646