学习曲线检查模型欠拟合&过拟合

from sklearn.model_selection import learning_curve
from sklearn.model_selection import ShuffleSplit
import numpy as np

def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
                        n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):
    plt.figure()
    plt.title(title)
    if ylim is not None:
        plt.ylim(*ylim)
        plt.xlabel("Training examples")
        plt.ylabel("Score")
        train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)
        train_scores_mean = np.mean(train_scores, axis=1)
        train_scores_std = np.std(train_scores, axis=1)
        test_scores_mean = np.mean(test_scores, axis=1)
        test_scores_std = np.std(test_scores, axis=1)
        plt.grid()
        
        plt.fill_between(
        train_sizes, train_scores_mean - train_scores_std,
        train_scores_mean + train_scores_std, alpha=0.1,color="r")
        
        plt.fill_between(
        train_sizes, test_scores_mean - test_scores_std,
        test_scores_mean + test_scores_std, alpha=0.1, color="g")
        
        plt.plot(train_sizes, train_scores_mean, 'o-', color="r",label="Training score")
        plt.plot(train_sizes, test_scores_mean, 'o-', color="g",label="Cross-validation score")
        plt.legend(loc="best")
    return plt

实例(员工离职预测(逻辑回归)):
title:图像的名字。

cv:默认cv=None,如果需要传入则如下:

cv : int, 交叉验证生成器或可迭代的可选项,确定交叉验证拆分策略。cv的可能输入是:
- 无,使用默认的3倍交叉验证,
- 整数,指定折叠数。
- 要用作交叉验证生成器的对象。
- 可迭代的yielding训练/测试分裂。

ShuffleSplit:我们这里设置cv,交叉验证使用ShuffleSplit方法,一共取得100组训练集与测试集,每次的测试集为20%,它返回的是每组训练集与测试集的下标索引,由此可以知道哪些是train,那些是test。

ylim:tuple, shape (ymin, ymax), 可选的。定义绘制的最小和最大y值,这里是(0.7,1.01)。

n_jobs : 整数,可选并行运行的作业数(默认值为1)。windows开多线程需要在"name"==__main__中运行。

title='Learning Curves'
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
estimator =LR     # 建模
plot_learning_curve(estimator, title, X, y, (0.7, 1.01), cv=cv, n_jobs=1)

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_43222937/article/details/84939036