机器学习-K近邻算法(KNN)

代码:

#encoding=utf-8
#导入鸢尾花数据库
from sklearn.datasets import load_iris
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

# sklearn.cross_validation was removed in scikit-learn 0.20; train_test_split
# lives in sklearn.model_selection. The legacy path is kept only as a fallback.
try:
    from sklearn.model_selection import train_test_split
except ImportError:
    from sklearn.cross_validation import train_test_split

def knniris():
    """Train and evaluate a k-nearest-neighbours classifier on the iris data.

    Loads the iris dataset, holds out 25% of it as a test set, standardizes
    the features, fits a default ``KNeighborsClassifier`` and then runs a
    ``GridSearchCV`` over ``n_neighbors`` in ``[3, 5, 7]`` with 10-fold
    cross-validation. Prints the grid search's score on the test set,
    followed by a per-class classification report for the default
    classifier's predictions.

    The train/test split is not seeded, so the printed numbers vary from
    run to run.

    Returns:
        None. Results are written to standard output.
    """
    # Load the dataset and split it into training and test parts.
    iris = load_iris()
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.25
    )

    # Standardize the features. The scaler is fitted on the training data
    # only; the test data is transformed with those same statistics so both
    # sets share one scale and no test-set information leaks into the
    # preprocessing step.
    std = StandardScaler()
    x_train = std.fit_transform(x_train)
    x_test = std.transform(x_test)

    # Fit the baseline estimator and predict the held-out samples.
    knn = KNeighborsClassifier()
    knn.fit(x_train, y_train)
    y_predict = knn.predict(x_test)

    # Grid search over the candidate numbers of neighbours.
    param = {"n_neighbors": [3, 5, 7]}
    gs = GridSearchCV(knn, param_grid=param, cv=10)
    gs.fit(x_train, y_train)

    # Score of the grid-searched model on the held-out test set.
    print(gs.score(x_test, y_test))

    # Per-class precision and recall of the baseline predictions.
    print("每个类别的精确率与召回率:",classification_report(y_test,y_predict,target_names=iris.target_names))
    return None
    
# Run the KNN iris demo only when this file is executed as a script.
if __name__ == "__main__":
    knniris()
    

结果:



猜你喜欢

转载自blog.csdn.net/poyue8754/article/details/80723845