关于网格搜索和交叉验证

定义理解

  • 交叉验证:
    • 意义:为了让被评估的模型更加的准确可信(交叉验证平均值最大的才是最可信的),
    • 作用:确定估计器最好的超参数是哪个。
    • 实际操作:把训练集分为几个等份,其中包括一份验证集(类似测试集)和多份训练集。而且你还要知道:这份验证集是不固定的,你分为多少份,验证集就有多少种可能;这份验证集的训练集是剩下的多份训练集之和,而不是某一份训练集
  • 网格搜索:指定模型估计器的超参数,程序自动的帮你使用穷举法来将所用的参数都运行一遍。
  • 网格搜索是要配合交叉验证一起使用的。

sklearn Api 介绍

  • API:from sklearn.model_selection import train_test_split, GridSearchCV
  • GridSearchCV重要参数介绍:estimator:估计器;param_grid:超参数集;cv:交叉验证集和训练集总数

实例code

  • 不可直接run
knn = KNeighborsClassifier()
param = {"n_neighbors": [3, 5, 10]}
# parameter introduction:knn:estimator;pre_dispatch:super parameter;cv:total number of cross validation and training sets
gc = GridSearchCV(estimator=knn, param_grid=param, cv=2)
gc.fit(x_train, y_train)
print("测试集上的准确率:\n", gc.score(x_test, y_test))
print("在交叉验证中最好的结果(平均值):\n", gc.best_score_)
print("选择的最好的参数模型(n_neighbors):\n", gc.best_estimator_)
print("每个超参数每次交叉验证的的结果及最好的结果(平均值):\n", gc.cv_results_)

还是附上输出吧

  • 找一下mean_test_score
测试集上的准确率:
 0.4879432624113475
在交叉验证中最好的结果(平均值):
 0.45783403685125834
选择的最好的参数模型(n_neighbors):
 KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=10, p=2,
                     weights='uniform')
每个超参数每次交叉验证的的结果及最好的结果(平均值)::
 {'mean_fit_time': array([0.01828647, 0.02197655, 0.01896199]), 'std_fit_time': array([0.00046884, 0.00282041, 0.00161343]), 'mean_score_time': array([0.20311983, 0.20907108,
0.24899689]), 'std_score_time': array([0.00736217, 0.00187851, 0.00653006]), 'param_n_neighbors': masked_array(data=[3, 5, 10],
             mask=[False, False, False],
       fill_value='?',
            dtype=object), 'params': [{'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 10}], 'split0_test_score': array([0.43853428, 0.45933806, 0.45957447]), 'split1_t
est_score': array([0.43059825, 0.45069756, 0.46086545]), 'split2_test_score': array([0.44029321, 0.44927879, 0.45306219]), 'mean_test_score': array([0.43647525, 0.45310481, 0.
45783404]), 'std_test_score': array([0.00421725, 0.00444547, 0.00341512]), 'rank_test_score': array([3, 2, 1])}

发布了55 篇原创文章 · 获赞 3 · 访问量 2721

猜你喜欢

转载自blog.csdn.net/rusi__/article/details/103828328