之前那篇文章,手工打造 1-NN 的算法,来实现水果的分类;
今天来主要总结一下 k-NN的算法原理,以及基于 sklearn 实现鸢尾花分类。
目录
1、k近邻算法原理
2、优缺点分析
3、应用——鸢尾花分类
一、k近邻算法原理
原理图参考 机器学习100天——第7天k-NN
k近邻是一种简单的但也常用的分类算法,也可以应用到回归计算。
k近邻是无参数学习,没有假设函数,它是基于实例的,在一个有监督的环境下使用。
做出预测
首先输入训练样本,然后让你准备测试的样本去循环遍历每一个训练样本,找到K个与测试样本距离最近的训练样本,以这几个训练样本大多数的分类为最后的预测值。
距离度量
- 欧式距离(最常用)坐标点的平方根
- 汉明距离
- 曼哈顿距离
- 闵氏距离
K的取值
k太小:容易受到异常值的影响
k太大:计算成本太高
二、优缺点分析
1、优点
- 简单,易于实现
- 易于理解
- 无需训练
- 无需估计参数
2、缺点
- 惰性算法,对测试样本分类时的计算量大,内存开销大
- 必须指定k值,且k值不易选择
三、鸢尾花分类
问题描述:
用knn算法实现一个鸢尾花的分类器
4个特征为 :花瓣长度、宽度 花萼的长度和宽度
标签为 :花的类别 0、1、2
1、加载数据
from sklearn import datasets
iris = datasets.load_iris()
# print(iris)
2、从数据中提取特征与标签
# 提取特征与标签
# 特征为 :花瓣长度、宽度 花萼的长度和宽度
# 标签为 :花的类别 0、1、2
x = iris['data']
y = iris['target']
# print(x, y)
3、划分数据集
from sklearn.model_selection import train_test_split
# 划分训练集与测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=1/5, random_state=0)
4、加载k近邻分类器以及预测
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier()
knn.fit(x_train, y_train)
y_pred = knn.predict(x_test) # 预测
print(knn.score(x_test, y_test)) # 评分,返回正确分类的比例
5、根据其中两个特征进行可视化
import matplotlib.pyplot as plt
# 绘制散点图
plt.scatter(iris.data[:, 2], iris.data[:, 3], c=iris.target)
plt.show()
6、运行结果
评分为 0.9666666666666667