(七)XGBoost融合sklearn API

Scikit-Learn API (Scikit-Learn wrapper interface for XGBoost)

import pickle
import xgboost as xgb

import numpy as np
from sklearn.model_selection import KFold, train_test_split, GridSearchCV
from sklearn.metrics import confusion_matrix, mean_squared_error
from sklearn.datasets import load_iris, load_digits, load_boston

1、二分类

# Seeded RNG shared by the KFold splitters below, so the shuffles are repeatable.
rng = np.random.RandomState(31337)
print("Zeros and Ones from the Digits dataset: binary classification")
# Pass n_class by keyword: scikit-learn 1.0+ makes load_digits' parameters
# keyword-only, so the positional form load_digits(2) raises TypeError.
# Restricting to 2 classes keeps only the digits 0 and 1 (binary problem).
digits = load_digits(n_class=2)
y = digits['target']
X = digits['data']
Zeros and Ones from the Digits dataset: binary classification
# Notebook display of the shapes: label vector first, then feature matrix.
(y.shape, X.shape)
((360,), (360, 64))
# 2-fold cross-validation: train a fresh classifier on each training half and
# report how it does on the held-out half.
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X):
    xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
    predictions = xgb_model.predict(X[test_index])
    actuals = y[test_index]
    # Confusion matrix: rows are true labels, columns are predicted labels.
    # (A `#` comment is required here; `//` is floor division in Python.)
    print(confusion_matrix(actuals, predictions))
[[87  0]
 [ 1 92]]
[[91  0]
 [ 3 86]]

2、多分类

print("Iris: multiclass classification")
iris = load_iris()
X, y = iris['data'], iris['target']
# Same 2-fold protocol as the binary case; XGBClassifier handles the
# multiclass labels directly.
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for fold_train, fold_test in kf.split(X):
    xgb_model = xgb.XGBClassifier()
    xgb_model.fit(X[fold_train], y[fold_train])
    predictions = xgb_model.predict(X[fold_test])
    actuals = y[fold_test]
    print(confusion_matrix(actuals, predictions))
Iris: multiclass classification
[[19  0  0]
 [ 0 31  3]
 [ 0  1 21]]
[[31  0  0]
 [ 0 16  0]
 [ 0  3 25]]

3、回归

print("Boston Housing: regression")
# NOTE(review): load_boston was deprecated in scikit-learn 1.0 and removed in
# 1.2; this section needs an older scikit-learn or a replacement dataset.
boston = load_boston()
X, y = boston['data'], boston['target']
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for fold_train, fold_test in kf.split(X):
    # Fit a fresh regressor on this fold and score it by mean squared error.
    xgb_model = xgb.XGBRegressor()
    xgb_model.fit(X[fold_train], y[fold_train])
    predictions = xgb_model.predict(X[fold_test])
    actuals = y[fold_test]
    print(mean_squared_error(actuals, predictions))
Boston Housing: regression
9.862814929045339
15.989962572880902

4、参数调优

print("Parameter optimization")
y = boston['target']
X = boston['data']
xgb_model = xgb.XGBRegressor()
# Hyper-parameter tuning by exhaustive grid search:
#   max_depth:    maximum depth of each base tree
#   n_estimators: number of boosted trees to fit
# verbose controls log verbosity: the higher the value, the more messages.
# cv=3 is pinned explicitly because scikit-learn 0.22 changed GridSearchCV's
# default from 3-fold to 5-fold, and the run shown below used 3 folds.
clf = GridSearchCV(xgb_model,
                   {'max_depth': [2, 4, 6],
                    'n_estimators': [50, 100, 200]},
                   cv=3, verbose=1)
Parameter optimization
# Run the search, then show the best mean CV score and the grid point behind it.
clf.fit(X, y)
best_score, best_params = clf.best_score_, clf.best_params_
print(best_score)
print(best_params)
Fitting 3 folds for each of 9 candidates, totalling 27 fits
0.5984879606490934
{'max_depth': 4, 'n_estimators': 100}


[Parallel(n_jobs=1)]: Done  27 out of  27 | elapsed:    1.1s finished
# The sklearn API models are picklable: save the fitted search, reload it,
# and check that both copies give the same predictions.
print("Pickling sklearn API models")
# Pickle needs binary mode. Context managers close (and flush) each file
# deterministically instead of relying on garbage collection.
with open("best_boston.pkl", "wb") as f:
    pickle.dump(clf, f)
# Only unpickle files you trust: pickle.load can execute arbitrary code.
with open("best_boston.pkl", "rb") as f:
    clf2 = pickle.load(f)
print(np.allclose(clf.predict(X), clf2.predict(X)))
Pickling sklearn API models
True

5、提前终止

# Early-stopping: hold out a validation split and stop adding trees once the
# validation AUC has not improved for 10 consecutive rounds.
X, y = digits['data'], digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = xgb.XGBClassifier()
validation = [(X_test, y_test)]
# NOTE(review): newer XGBoost releases (1.6+) deprecate passing
# early_stopping_rounds / eval_metric to fit() in favour of the constructor;
# confirm against the installed version.
clf.fit(
    X_train,
    y_train,
    eval_set=validation,
    eval_metric="auc",
    early_stopping_rounds=10,
)
[0] validation_0-auc:0.999497
Will train until validation_0-auc hasn't improved in 10 rounds.
[1] validation_0-auc:0.999497
[2] validation_0-auc:0.999497
[3] validation_0-auc:0.999749
[4] validation_0-auc:0.999749
[5] validation_0-auc:0.999749
[6] validation_0-auc:0.999749
[7] validation_0-auc:0.999749
[8] validation_0-auc:0.999749
[9] validation_0-auc:0.999749
[10]    validation_0-auc:1
[11]    validation_0-auc:1
[12]    validation_0-auc:1
[13]    validation_0-auc:1
[14]    validation_0-auc:1
[15]    validation_0-auc:1
[16]    validation_0-auc:1
[17]    validation_0-auc:1
[18]    validation_0-auc:1
[19]    validation_0-auc:1
[20]    validation_0-auc:1
Stopping. Best iteration:
[10]    validation_0-auc:1






XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,
       max_depth=3, min_child_weight=1, missing=None, n_estimators=100,
       n_jobs=1, nthread=None, objective='binary:logistic', random_state=0,
       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
       silent=True, subsample=1)

转载自 blog.csdn.net/hao5335156/article/details/81176214