sklearn KNN实现鸢尾花分类

基于sklearn的KNN算法实现鸢尾花的分类

数据集下载:GitHub

1.数据准备

  学习分类问题,鸢尾花数据集是比较常用的样例。本文使用的是原始数据,有效数据总共150条,内容和格式未做修改。
  前10行的数据如下所示:
样例数据

2.导入几个包

  包括pandas和sklearn的几个包:

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

3.划分训练集和测试集

  首先,将数据集的前四列提取出来作为特征,最后一列作为分类标签;然后,利用train_test_split()将特征和标签随机的划分为训练集和测试,这里设定测试集的比例为20%,也就是30条;最后,将训练集和测试集的标签转换为一维数组(不转换也可以,只是为了看着方便)。

# 读取数据
iris_data_set = pd.read_csv("D:\\iris.csv")
# x是4列特征
x = iris_data_set.iloc[:, 0:4].values
# y是1列标签
y = iris_data_set.iloc[:, -1].values

# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)

# 将特征转为一维数组
y_train = y_train.flatten()
y_test = y_test.flatten()

4.训练模型并预测

  首先,通过调用KNeighborsClassifier()函数建立KNN算法模型,这里的n_neighbors=3表示将K值设为3;然后,输入训练集的特征和分类标签进行训练;最后,将测试集的特征应用在模型上,进行分类并获得分类结果。

# 建模
knn_model = KNeighborsClassifier(n_neighbors=3)
# 训练
knn_model.fit(x_train, y_train)
# 预测
y_pre = knn_model.predict(x_test)

5.结果输出及分析

  将测试集的实际分类和模型预测的分类打印出来,可以直观地进行比较。
  混淆矩阵(confusion matrix)是评价分类模型优劣的重要依据,调用confusion_matrix()即可返回模型的混淆矩阵。
  评价分类模型有很多指标,可以通过classification_report()函数进行输出。

print("正确标签:", y_test)
print("预测结果:", y_pre)

# 混淆矩阵
conf_mat = confusion_matrix(y_test, y_pre)
print(conf_mat)

# 分类指标文本报告(精确率、召回率、F1值等)
print(classification_report(y_test, y_pre))

  最终结果如下:
在这里插入图片描述

6.总结

  可以看出,基于sklearn的API,不用写太多的代码,就能够很方便地进行数据集划分、模型建立、模型训练和分类预测,而且也可以对模型的分类指标进行计算。
  完整的代码如下所示:

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

# 读取数据
iris_data_set = pd.read_csv("D:\\iris.csv")
# x是4列特征
x = iris_data_set.iloc[:, 0:4].values
# y是1列标签
y = iris_data_set.iloc[:, 4:].values

# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)

# 将特征转为一维数组
y_train = y_train.flatten()
y_test = y_test.flatten()

# 建模、训练、预测
knn_model = KNeighborsClassifier()
knn_model.fit(x_train, y_train)
y_pre = knn_model.predict(x_test)

print("正确标签:", y_test)
print("预测结果:", y_pre)

# 混淆矩阵
conf_mat = confusion_matrix(y_test, y_pre)
print(conf_mat)

# 分类指标文本报告(精确率、召回率、F1值等)
print(classification_report(y_test, y_pre))

扩展学习

  1. Python 基于BP神经网络的鸢尾花分类
  2. 机器学习分类问题指标理解——准确率(accuracy)、精确率(precision)、召回率(recall)、F1-Score、ROC曲线、P-R曲线、AUC面积
  3. Python 多维数据可视化

欢迎关注我的微信公众号:

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/michael_f2008/article/details/107574888