2、python机器学习基础教程——K近邻算法鸢尾花分类

一、第一个K近邻算法应用:鸢尾花分类

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# 加载数据
iris_dataset = load_iris()

# 实例化模型
knn = KNeighborsClassifier(n_neighbors=1)

# 切分训练和测试数据集
X_train,X_test,y_train,y_test = train_test_split(iris_dataset["data"], iris_dataset["target"],random_state=0)

#训练
knn.fit(X_train, y_train)

# 评估模型
print("Test set score:{:.2f}".format(knn.score(X_test,y_test)))

# 预测
X_new = np.array([[5,2.9,1,0.2]])
prediction = knn.predict(X_new)
print("Predicted target name:{}".format(iris_dataset["target_names"][prediction]))

以上代码段包含了应用scikit-learn中人和机器学习算法的核心代码。

fit、predict和score方法是scikit-learn监督学习模型中最常用的接口。

猜你喜欢

转载自yq.aliyun.com/articles/684338