机械学习:k-NN算法(也叫k近邻算法)

一、kNN算法基础

# kNN:k-Nearest Neighboors

# 多用于解决分裂问题

 1)特点:

  1. 思想极度简单;
  2. 应用数学知识少(近乎为零);
  3. 效果少;
  4. 可以解释机械学习算法使用过程中的很多细节问题
  5. 更完整的刻画机械学习应用的流程;

 2)思想:

  • 根本思想:两个样本,如果它们的特征足够相似,它们就有更高的概率属于同一个类别;
  • 问题:根据现有训练数据集,判断新的样本属于哪种类型
  • 方法/思路
  1. 求新样本点在样本空间内与所有训练样本的欧拉距离;
  2. 对欧拉距离排序,找出最近的k个点;
  3. 对k个点分类统计,看哪种类型的点数量最多,此类型即为对新样本的预测类型;

 3)代码实现过程:

  • 具体代码:
    import numpy as np
    import matplotlib.pyplot as plt
    
    raw_data_x = [[3.3935, 2.3312],
                  [3.1101, 1.7815],
                  [1.3438, 3.3684],
                  [3.5823, 4.6792],
                  [2.2804, 2.8670],
                  [7.4234, 4.6965],
                  [5.7451, 3.5340],
                  [9.1722, 2.5111],
                  [7.7928, 3.4241],
                  [7.9398, 0.7916]]
    raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    
    # 训练集样本的data
    x_train = np.array(raw_data_x)
    # 训练集样本的label
    y_train = np.array(raw_data_y)
    
    # 1)绘制训练集样本与新样本的散点图
    # 根据样本类型(0、1两种类型),绘制所有样本的各特征点
    plt.scatter(x_train[y_train == 0, 0], x_train[y_train == 0, 1], color = 'g')
    plt.scatter(x_train[y_train == 1, 0], x_train[y_train == 1, 1], color = 'r')
    # 新样本
    x = np.array([8.0936, 3.3657])
    # 将新样本的特征点绘制在训练集的样本空间
    plt.scatter(x[0], x[1], color = 'b')
    plt.show()
    
    
    # 2)在特征空间中,计算训练集样本中的所有点与新样本的点的欧拉距离
    from math import sqrt
    # math模块下的sqrt函数:对数值开平方sqrt(number)
    distances = []
    for x_train in x_train:
        d = sqrt(np.sum((x - x_train) ** 2))
        distances.append(d)
    
    # 也可以用list的生成表达式实现:
    # distances = [sqrt(np.sum((x - x_train) ** 2)) for x_train in x_train]
    
    
    # 3)找出距离新样本最近的k个点,并得到对新样本的预测类型
    nearest = np.argsort(distances)
    k = 6
    # 找出距离最近的k个点的类型
    topK_y = [y_train[i] for i in nearest[:k]]
    
    # 根据类别对k个点的数量进行统计
    from collections import Counter
    votes = Counter(topK_y)
    
    # 获取所需的预测类型:predict_y
    predict_y = votes.most_common(1)[0][0]
  • 代码中的其它Python知识:
  1. math模块下的sprt()方法:对数开平方;

    from math import sqrt
    print(sprt(9))
    # 3
  2. collections模块下的Counter()方法:对列表中的数据进行分类统计,生产一个Counter对象;
    from collections import Counter
    
    my_list = [0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3]
    print(Counter(my_list))
    # 一个Counter对象:Counter({0: 2, 1: 3, 2: 4, 3: 5})
  3. Counter对象的most_common()方法:Counter.most_common(n),返回Counter对象中数量最多的n种数据,返回一个list,list的每个元素为一个tuple;
    from collections import Counter
    
    my_list = [0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3]
    votes = Counter(my_list)
    print(votes.most_common(2))
    # [(3, 5), (2, 4)]

猜你喜欢

转载自www.cnblogs.com/volcao/p/9072815.html