机器学习——k近邻算法(KNN)

import math
import numpy as np
from collections import Counter
class KNNClassfiy(object):
    """k-nearest-neighbours classifier: Euclidean distance + majority vote.

    Usage: ``clf = KNNClassfiy(k).fit(xTrain, yTrain)`` followed by
    ``clf.predict(X)``. The (misspelled) class name is kept so existing
    callers keep working.
    """

    def __init__(self, k):
        """Create an unfitted classifier.

        Args:
            k: number of nearest training samples that vote on each
                prediction; must be >= 1.

        Raises:
            AssertionError: if ``k < 1``.
        """
        # Validate k.
        assert k >= 1, 'k must be valid'
        self.k = k
        self._xTrain = None  # training samples, one row per sample (set by fit)
        self._yTrain = None  # labels aligned with the rows of _xTrain (set by fit)

    def fit(self, xTrain, yTrain):
        """Store the training data (KNN is lazy: no model is built here).

        Args:
            xTrain: 2-D array-like, one row per training sample.
            yTrain: array-like of labels, one per row of ``xTrain``.

        Returns:
            self, so that ``fit(...).predict(...)`` can be chained.

        Raises:
            AssertionError: if the sample counts differ, or if there are
                fewer than ``k`` training samples.
        """
        # Accept lists as well as ndarrays (asarray is a no-op on ndarrays).
        xTrain = np.asarray(xTrain)
        yTrain = np.asarray(yTrain)
        # The training set must have exactly one label per sample.
        assert xTrain.shape[0] == yTrain.shape[0], \
            'The size of xTrain must be equals to the size of yTrain'
        # We need at least k samples to find k neighbours.
        assert self.k <= xTrain.shape[0], \
            'The size of xTrain must be least at k'
        self._xTrain = xTrain
        self._yTrain = yTrain
        return self

    def predict(self, X_predict):
        """Predict a label for every row of ``X_predict``.

        Args:
            X_predict: 2-D array-like whose rows have the same number of
                features as the training samples.

        Returns:
            np.ndarray with one predicted label per row of ``X_predict``.

        Raises:
            AssertionError: if the model is not fitted, ``X_predict`` is not
                2-D, or its feature count differs from the training data.
        """
        # Check the fitted state first. Otherwise the shape check below would
        # dereference None and raise AttributeError instead of this message.
        assert self._xTrain is not None and self._yTrain is not None, \
            'must fit before predict'
        X_predict = np.asarray(X_predict)
        assert X_predict.ndim == 2, 'X_predict must be a 2-D array'
        assert X_predict.shape[1] == self._xTrain.shape[1], \
            'The feature of x must be equal to xTrain'
        return np.array([self._predict(x) for x in X_predict])

    def _predict(self, x):
        """Return the majority label among the k training samples nearest to ``x``."""
        # Euclidean distance from x to every training row, in one vectorised step.
        distances = np.sqrt(np.sum((self._xTrain - x) ** 2, axis=1))
        nearest = np.argsort(distances)
        top_y = self._yTrain[nearest[:self.k]]
        votes = Counter(top_y)
        # On a tied vote, most_common keeps first-encountered order. Because
        # top_y is in distance order, the tied label whose nearest member is
        # closest wins.
        return votes.most_common(1)[0][0]

    def __repr__(self):
        # __repr__ must return a str; returning the int k made repr() raise TypeError.
        return '{}(k={!r})'.format(type(self).__name__, self.k)

# Demo: learn from ten labelled 2-D points, then classify two new points.
xTrain = np.array([
    [4.5, 3.2],
    [5.8, 4.1],
    [6.7, 5.3],
    [8.6, 7.1],
    [3.8, 2.5],
    [5.3, 4.4],
    [9.4, 8.6],
    [11.8, 9.4],
    [3.8, 3.2],
    [12.8, 10.1],
])
yTrain = np.array([0, 0, 1, 1, 0, 0, 1, 1, 0, 1])
x_predict = np.array([
    [6.9, 5.7],
    [3.4, 2.8],
])

# Fit first, then predict; fit() returns the classifier itself, so chain it.
KNN_clf = KNNClassfiy(k=6)
a = KNN_clf.fit(xTrain=xTrain, yTrain=yTrain).predict(x_predict)
print(a[0], a[1])

代码比较简单，主要逻辑在预测部分（`_predict` 方法）：先计算待预测样本到每个训练样本的欧氏距离，再取距离最近的 k 个样本，最后按多数投票决定类别。

调用 matplotlib 绘制数据分布图（绘图代码未在文中给出）：

（原文此处为数据分布图，图片未能保留。）

步骤可简化如下：

  • 确定 k 值（参与投票的近邻个数）
  • 用训练数据集拟合模型（`fit`，KNN 只保存训练数据）
  • 调用预测函数（`predict`）对新样本进行分类

K 近邻算法主要用于解决分类问题，是机器学习中最简单、最基础的算法之一。

发布了123 篇原创文章 · 获赞 74 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/weixin_43927892/article/details/103355712