Learning the sklearn library: the k-NN algorithm

k-nearest neighbors classification and k-nearest neighbors regression

import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsRegressor
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
import mglearn
import numpy as np
############# k-nearest neighbors classification on the forge dataset
X,y = mglearn.datasets.make_forge()
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state = 0)
clf = KNeighborsClassifier(n_neighbors = 3)
clf.fit(X_train,y_train)

print("Test set predictions:{}".format(clf.predict(X_test)))
print("Test set accuracy:{:.2f}".format(clf.score(X_test,y_test)))

fig, axes = plt.subplots(1,3,figsize = (10,3))
for n_neighbors,ax in zip([1,3,9],axes):
    clf = KNeighborsClassifier(n_neighbors = n_neighbors).fit(X,y)
    # plot the decision boundary, filled by predicted class
    mglearn.plots.plot_2d_separator(clf,X,fill = True, eps = 0.5,ax = ax, alpha = 0.4)
    mglearn.discrete_scatter(X[:,0],X[:,1],y,ax = ax) # plot the data points
    
    ax.set_title("{} neighbor(s)".format(n_neighbors))
    ax.set_xlabel("feature 0")
    ax.set_ylabel("feature 1")
    ax.legend(loc = 3)

############# effect of n_neighbors on training/test accuracy (breast cancer dataset)
from sklearn.datasets import load_breast_cancer
cancer = load_breast_cancer()
X_train,X_test,y_train,y_test = train_test_split(
    cancer.data,cancer.target,stratify = cancer.target,random_state = 66)
training_accuracy = []
test_accuracy = []
neighbors_settings = range(1,11)

for n_neighbors in neighbors_settings:
    clf = KNeighborsClassifier(n_neighbors = n_neighbors)
    clf.fit(X_train,y_train)
    training_accuracy.append(clf.score(X_train,y_train))
    test_accuracy.append(clf.score(X_test,y_test))

fig, ax = plt.subplots(1,1,figsize = (10,6))
ax.plot(neighbors_settings,training_accuracy, label = "training accuracy")
ax.plot(neighbors_settings,test_accuracy, label = "test accuracy")
ax.set_xlabel("n_neighbors")
ax.set_ylabel("Accuracy")
ax.legend()

########## k-nearest neighbors regression on the wave dataset
X,y = mglearn.datasets.make_wave(n_samples=40)
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state = 0)

fig,axes = plt.subplots(1,3,figsize=(15,4))
line = np.linspace(-3,3,1000).reshape(-1,1)
for n_neighbors,ax in zip([1,3,9],axes):
    reg = KNeighborsRegressor(n_neighbors = n_neighbors)
    reg.fit(X_train,y_train)
    
    print("Test set predictions:{}".format(reg.predict(X_test)))
    print("Test set accuracy:{:.2f}".format(reg.score(X_test,y_test)))

    ax.plot(line,reg.predict(line))
    
    ax.plot(X_train,y_train,'^',c = mglearn.cm2(0),markersize = 8)
    ax.plot(X_test,y_test,'.',c = mglearn.cm2(1),markersize = 8)
    
    ax.set_title("{}neighbor(s)\n train score:{:.2f} test score:{:.2f}".format(
    n_neighbors,reg.score(X_train,y_train),reg.score(X_test,y_test)))
    ax.set_xlabel('Feature')
    ax.set_ylabel('Target')
    ax.legend(['Model predictions','Training data/target','Test data/target'],loc = 'best')

Notes on functions used in the code (a short illustrative sketch follows the list of links below)

  1. A brief guide to placing a matplotlib legend outside the plot in Python
    https://blog.csdn.net/yywan1314520/article/details/53740001/

  2. [python] A summary of the pandas plot() plotting commands
    https://blog.csdn.net/u013084616/article/details/79064408

  3. matplotlib basics in Python
    https://www.cnblogs.com/liutongqing/p/6985805.html

  4. The TensorFlow reshape operation tf.reshape()
    https://blog.csdn.net/m0_37592397/article/details/78695318

  5. numpy.linspace explained in detail
    https://blog.csdn.net/you_are_my_dream/article/details/53493752

  6. Understanding fig, ax = plt.subplots()
    https://www.jianshu.com/p/decf22446316

  7. How to use train_test_split
    https://blog.csdn.net/mrxjh/article/details/78481578

  8. make_blobs, a generator for clustering data
    https://blog.csdn.net/kevinelstri/article/details/52622960

  9. The datasets that ship with sklearn
    https://www.cnblogs.com/nolonely/p/6980160.html

  10. Python DeprecationWarning errors
    https://blog.csdn.net/qq_38734403/article/details/79779713
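
A few of the functions these links cover can be shown together in one place. The following is a rough, self-contained sketch of my own (not code from the original post), assuming the same libraries used above: np.linspace(...).reshape(-1, 1) to build a column-vector input, train_test_split with stratify, and placing a legend outside the axes with bbox_to_anchor.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# numpy.linspace + reshape(-1, 1): a column vector of evenly spaced values,
# the same trick used above to build the input for reg.predict(line)
line = np.linspace(-3, 3, 5).reshape(-1, 1)
print(line.shape)  # (5, 1)

# train_test_split: stratify keeps the class proportions of y identical in
# the training and test splits
cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=66)

neighbors_settings = range(1, 11)
test_accuracy = []
for n in neighbors_settings:
    clf = KNeighborsClassifier(n_neighbors=n).fit(X_train, y_train)
    test_accuracy.append(clf.score(X_test, y_test))

# legend outside the axes: bbox_to_anchor pins the legend just to the right
# of the plotting area; subplots_adjust leaves room for it
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(neighbors_settings, test_accuracy, label="test accuracy")
ax.set_xlabel("n_neighbors")
ax.set_ylabel("Accuracy")
ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1.0))
fig.subplots_adjust(right=0.75)
plt.show()

Here bbox_to_anchor is given in axes coordinates, so (1.02, 1.0) places the legend's upper-left corner just past the right edge of the axes.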

Reposted from blog.csdn.net/thj19980720/article/details/83066686