代码:
#encoding=utf-8 #导入鸢尾花数据库 from sklearn.datasets import load_iris from sklearn.cross_validation import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.neighbors import KNeighborsClassifier from sklearn.metrics import classification_report from sklearn.model_selection import GridSearchCV def knniris(): #数据集获取和分割 iris = load_iris() x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,test_size=0.25) #进行标准化 std = StandardScaler() x_train = std.fit_transform(x_train) x_test = std.fit_transform(x_test) #estimator(评估)流程 knn = KNeighborsClassifier() #得出模型 knn.fit(x_train,y_train) #进行预测或者得出精度 y_predict = knn.predict(x_test) #验证精确度score score = knn.score(x_test,y_test) #print(score) #通过网络搜索,n_neighbors为参数列表 param = {"n_neighbors":[3,5,7]} gs = GridSearchCV(knn,param_grid=param,cv=10) #建立模型 gs.fit(x_train,y_train) #预测数据 print(gs.score(x_test,y_test)) #分析模型的准确率和召回率 print("每个类别的精确率与召回率:",classification_report(y_test,y_predict,target_names=iris.target_names)) return None if __name__=="__main__": knniris()
结果: