一、简介
之前看论文的时候,看到过grid search进行超参调优,一开始觉得应该是个比较高大上的东西,其实后来一看就是个暴力搜索。如果一个模型一共有3个参数,分别是A、B、C,假如A有3个选择,B有4个选择,C有5个选择,为了选出最好的模型,就需要一个参数一个参数试过去,从这60( 3 ∗ 4 ∗ 5 3*4*5 3∗4∗5)个模型里面找出一个符合自己需求的。
注意:grid search是一种调优的思路,和实现方式是没啥关系的。
你既可以通过简单的for循环,也可以直接用sklearn里面的类sklearn.model_selection.GridSearchCV。
这篇博客我觉得讲得挺好的:调参必备–Grid Search网格搜索
grid search缺点也很明显,比较耗时,所以在实际应用中,需要减少调优参数的个数并且缩小每个参数的搜索范围。
二、代码实现
sklearn官网提供的代码就挺好。sklearn官网参数调优示例
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.svm import SVC
print(__doc__)
# Loading the Digits dataset
digits = datasets.load_digits()
# To apply an classifier on this data, we need to flatten the image, to
# turn the data in a (samples, feature) matrix:
n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))
y = digits.target
# Split the dataset in two equal parts
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.5, random_state=0)
# Set the parameters by cross-validation
tuned_parameters = [{
'kernel': ['rbf'], 'gamma': [1e-3, 1e-4],
'C': [1, 10, 100, 1000]},
{
'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]
scores = ['precision', 'recall']
for score in scores:
print("# Tuning hyper-parameters for %s" % score)
print()
clf = GridSearchCV(
SVC(), tuned_parameters, scoring='%s_macro' % score
)
clf.fit(X_train, y_train)
print("Best parameters set found on development set:")
print()
print(clf.best_params_)
print()
print("Grid scores on development set:")
print()
means = clf.cv_results_['mean_test_score']
stds = clf.cv_results_['std_test_score']
for mean, std, params in zip(means, stds, clf.cv_results_['params']):
print("%0.3f (+/-%0.03f) for %r"
% (mean, std * 2, params))
print()
print("Detailed classification report:")
print()
print("The model is trained on the full development set.")
print("The scores are computed on the full evaluation set.")
print()
y_true, y_pred = y_test, clf.predict(X_test)
print(classification_report(y_true, y_pred))
print()
# Note the problem is too easy: the hyperparameter plateau is too flat and the
# output model is the same for precision and recall with ties in quality.
三、其他
官网是全英文的,如果嫌麻烦,可以看看这篇博客:【GridSearchCV,CV调优超参数使用】【K最近邻分类器 KNeighborsClassifier 使用】【交叉验证】