机器学习算法之 K近邻算法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/ZZh1301051836/article/details/89036367

1 算法思想

寻找与待分类样本最近的K个点,投票决定该样本是什么类别。

2 KNN算法的三要素

  1. 距离度量
    欧氏距离:m维空间中两个点之间的直线距离。
    d = ( x i y i ) 2 d = \sqrt{\sum(x_i-y_i)^2}
    曼哈顿距离:曼哈顿大街上一个十字路口到另一个十字路口。
    d = x i y i d = \sum{|x_i-y_i|}
    切比雪夫距离:象棋中的“帅”从一个点移动到另一个点的步数就是。
    d = m a x x i y i d = max{|x_i-y_i|}
    夹角余弦定理:
    c o s = x i y i x i 2 y i 2 cos=\frac{\sum{x_i*y_i}}{\sqrt{\sum{x_i^2}}\sqrt{y_i^2}}

  2. k的选择
    须反复尝试和验证。

  3. 算法优化
    对距离加权:距离更近的近邻赋予更大的权重。
    对样本加权:对更值得信任的样本赋予更大的权重。

3 程序实现

#K近邻算法
import numpy as np
from math import *

from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score

#计算欧氏距离
def e_distance(x,y):
    return sqrt(sum(pow(a-b,2) for a,b in zip(x,y)))

# 计算曼哈顿距离:d = sum|x1i-x2i|
def m_distance(x,y):
    return sum(abs(x-y))

#计算切比雪夫距离:象棋 帅 单格跳
def q_distance(x,y):
    return max(abs(x-y))

#夹角余弦定理
def cos_distance(x,y):
    return np.dot(x,y)/(np.linalg.norm(x)*np.linalg.norm(y))

if __name__ == '__main__':
    # x = np.array([1, 1])
    # y = np.array([4, 5])

    # print(e_distance(x,y))
    # print(m_distance(x,y))
    # print(q_distance(x,y))
    # print(cos_distance(x,y))

    #采用鸢尾花数据集
    iris = datasets.load_iris()

    X = iris.data
    Y = iris.target

    X_train,X_test,Y_train,Y_test = train_test_split(X,Y,test_size=0.2,random_state=1)

    # 创建KNN分类器,并拟合数据
    knn = KNeighborsClassifier()
    knn.fit(X_train,Y_train)

    #在测试集上预测
    predictions = knn.predict(X_test)
    print(accuracy_score(Y_test,predictions))
    print(confusion_matrix(Y_test,predictions))
    print(classification_report(Y_test,predictions))

执行结果:

1.0
[[11  0  0]
 [ 0 13  0]
 [ 0  0  6]]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        11
           1       1.00      1.00      1.00        13
           2       1.00      1.00      1.00         6

   micro avg       1.00      1.00      1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

结果解释:
混淆矩阵请参考:混淆矩阵_百度百科
分类报告:机器学习笔记--classification_report&精确度/召回率/F1值

不知道你懂了没,反正我懂了,所以就不想解释了,2333~

猜你喜欢

转载自blog.csdn.net/ZZh1301051836/article/details/89036367