sparkml_实战全流程_LogisticRegression(二)

交叉验证
网格搜索

参考:
https://www.jianshu.com/p/20456b512fa7


# Hyperparameter tuning with grid search, evaluated by K-fold cross-validation.
import pyspark.ml.tuning as tune

logistic = cl.LogisticRegression(labelCol='INFANT_ALIVE_AT_REPORT')

# Candidate grid: every combination of maxIter and regParam (3 x 3 = 9 maps).
grid = (
    tune.ParamGridBuilder()
    .addGrid(logistic.maxIter, [5, 10, 50])
    .addGrid(logistic.regParam, [0.01, 0.05, 0.3])
    .build()
)

# The evaluator is the yardstick used to compare the candidate models.
evaluator = ev.BinaryClassificationEvaluator(
    rawPredictionCol='probability',
    labelCol='INFANT_ALIVE_AT_REPORT',
)

# K-fold cross-validation (3 folds) scores the model for each parameter map.
cv = tune.CrossValidator(
    estimator=logistic,
    estimatorParamMaps=grid,
    evaluator=evaluator,
    numFolds=3,
)

# The raw data cannot be fed to the estimator directly, so build a
# feature-construction pipeline first.
pipeline = Pipeline(stages=[encoder, featuresCreator])

# Split the data again for this experiment; the seed keeps the split repeatable.
birth_train, birth_test = births.randomSplit([0.7, 0.3], seed=123)
data_transformer = pipeline.fit(birth_train)
data_test = data_transformer.transform(birth_test)

# Search for the best parameter combination; cvModel holds the best model found.
cvModel = cv.fit(data_transformer.transform(birth_train))
results = cvModel.transform(data_test)

# Check how the selected model performs on the held-out test set.
print(evaluator.evaluate(results, {evaluator.metricName: 'areaUnderROC'}))
print(evaluator.evaluate(results, {evaluator.metricName: 'areaUnderPR'}))
# Output recorded in the original post:
# 0.735848884034915
# 0.6959036715961695
# Inspect the best parameters found by cross-validation.
# Pair each parameter map (as a list of one-entry {name: value} dicts) with
# its average cross-validation metric.
results = [
    (
        [{key.name: paramValue} for key, paramValue in params.items()],
        metric,
    )
    for params, metric in zip(
        cvModel.getEstimatorParamMaps(),
        cvModel.avgMetrics,
    )
]

# Highest average metric first; the first entry is the best combination.
# Left as a bare expression so a notebook cell displays it.
sorted(results, key=lambda el: el[1], reverse=True)[0]
# Alternatively:
# Rank every parameter combination by its average cross-validation metric.
param_maps = cvModel.getEstimatorParamMaps()
eval_metrics = cvModel.avgMetrics

param_res = []
for params, metric in zip(param_maps, eval_metrics):
    # Flatten the map into a plain {parameter name: value} dict.
    param_metric = {key.name: param_val for key, param_val in params.items()}
    param_res.append((param_metric, metric))

# Highest average metric first.
# Left as a bare expression so a notebook cell displays it.
sorted(param_res, key=lambda x: x[1], reverse=True)
# Output recorded in the original post:
# [({'maxIter': 50, 'regParam': 0.01}, 0.7406291618177623),
#  ({'maxIter': 10, 'regParam': 0.01}, 0.735580969909943),
#  ({'maxIter': 50, 'regParam': 0.05}, 0.7355100622938429),
#  ({'maxIter': 10, 'regParam': 0.05}, 0.7351586303619441),
#  ({'maxIter': 10, 'regParam': 0.3}, 0.7248698034708339),
#  ({'maxIter': 50, 'regParam': 0.3}, 0.7214679272915997),
#  ({'maxIter': 5, 'regParam': 0.3}, 0.7180255703028883),
#  ({'maxIter': 5, 'regParam': 0.01}, 0.7179304617840288),
#  ({'maxIter': 5, 'regParam': 0.05}, 0.7173397593133481)]
发布了273 篇原创文章 · 获赞 1 · 访问量 4687

猜你喜欢

转载自blog.csdn.net/wj1298250240/article/details/103947789