Reference: http://blog.csdn.net/han_xiaoyang/article/details/49797143
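Purpose:
Check whether a model is overfitting.
Typical sign of overfitting: as the number of training samples grows, the score on the training set stays high while the score on the cross-validation set stays noticeably lower, leaving a large gap between the two curves.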
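The gap described above can be measured directly from the output of learning_curve. The following is a minimal sketch, not the post's own code: the synthetic data from make_classification and the unpruned DecisionTreeClassifier are assumptions chosen only because such a tree tends to memorize its training set.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import learning_curve
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (an assumption for illustration, not the data used in this post)
X_demo, y_demo = make_classification(n_samples=300, n_features=10, random_state=0)

# A fully grown decision tree usually fits its training data almost perfectly
model = DecisionTreeClassifier(random_state=0)

sizes, train_scores, cv_scores = learning_curve(
    model, X_demo, y_demo, cv=5, train_sizes=[50, 100, 150, 200])

# Average over the folds, then take the difference between the two curves
train_mean = train_scores.mean(axis=1)
cv_mean = cv_scores.mean(axis=1)
gap = train_mean - cv_mean

for n, tr, cv_score, g in zip(sizes, train_mean, cv_mean, gap):
    print(f"n={n:3d}  train={tr:.3f}  cv={cv_score:.3f}  gap={g:.3f}")
```

A gap that stays large even as the training size grows is the pattern to watch for; a gap that shrinks steadily suggests that more data is helping.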
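Parameters:
train_sizes, train_scores, test_scores = learning_curve(...)
Inputs:
estimator : the classifier to evaluate.
title : title of the plot (used by the plot_learning_curve wrapper in the example, not by learning_curve itself).
X : input features, a NumPy array.
y : target vector.
ylim : tuple (ymin, ymax) setting the lowest and highest points of the y-axis (wrapper only).
cv : number of cross-validation folds; one fold is held out as the validation set and the remaining n-1 folds are used for training (the default was 3 folds in older scikit-learn releases and has been 5 folds since version 0.22).
n_jobs : number of parallel jobs (default 1).
Outputs:
train_sizes_abs : the numbers of training samples actually used.
train_scores : accuracy on the training set.
test_scores : accuracy on the cross-validation set.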
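To make the three outputs concrete, here is a small sketch of what learning_curve returns. The toy data set and the chosen sizes are assumptions for illustration only. Each score array has one row per training size and one column per cross-validation fold, which is why the example further down averages over axis=1.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import learning_curve
from sklearn.naive_bayes import GaussianNB

# Hypothetical toy data, generated only for this illustration
X_toy, y_toy = make_classification(n_samples=120, n_features=5, random_state=42)

train_sizes_abs, train_scores, test_scores = learning_curve(
    GaussianNB(), X_toy, y_toy, cv=4, train_sizes=[20, 40, 80])

print(train_sizes_abs)      # the training-set sizes that were actually used
print(train_scores.shape)   # (number of sizes, number of folds)
print(test_scores.shape)    # same layout as train_scores

# Averaging over the folds gives one curve point per training size
train_mean = train_scores.mean(axis=1)
test_mean = test_scores.mean(axis=1)
print(train_mean, test_mean)
```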
Python example:
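from sklearn.naive_bayes import GaussianNB
import numpy as np
import matplotlib.pyplot as plt
# learning_curve lives in sklearn.model_selection; the old sklearn.learning_curve
# module was removed in scikit-learn 0.20
from sklearn.model_selection import learning_curve  # used to check for overfitting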
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None, n_jobs=1,
                        train_sizes=np.linspace(.05, 1., 20), verbose=0, plot=True):
    """
    Plot the learning curve of a model on the given data.

    Parameters
    ----------
    estimator : the classifier you are using.
    title : title of the plot.
    X : input features, a NumPy array.
    y : target vector.
    ylim : tuple (ymin, ymax) setting the lowest and highest points of the y-axis.
    cv : number of cross-validation folds; one fold is the validation set and the
        remaining n-1 folds are used for training (None uses scikit-learn's default).
    n_jobs : number of parallel jobs (default 1).
    """
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes, verbose=verbose)

    # Mean and standard deviation across the folds, one value per training size
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    if plot:
        plt.figure()
        plt.title(title)
        if ylim is not None:
            plt.ylim(*ylim)
        plt.xlabel("train_sample")
        plt.ylabel("score")
        plt.grid()

        # Shaded bands show the mean plus/minus one standard deviation
        plt.fill_between(train_sizes, train_scores_mean - train_scores_std,
                         train_scores_mean + train_scores_std, alpha=0.1, color="b")
        plt.fill_between(train_sizes, test_scores_mean - test_scores_std,
                         test_scores_mean + test_scores_std, alpha=0.1, color="r")
        plt.plot(train_sizes, train_scores_mean, 'o-', color="b", label="train_score")
        plt.plot(train_sizes, test_scores_mean, 'o-', color="r", label="cross_validation_score")

        plt.legend(loc="best")
        # Save before show(): after the window is closed the figure may be empty
        plt.savefig("learn_curve.jpg")
        plt.show()

    # Midpoint and width of the gap between the two curves at the largest training size
    midpoint = ((train_scores_mean[-1] + train_scores_std[-1]) + (test_scores_mean[-1] - test_scores_std[-1])) / 2
    diff = (train_scores_mean[-1] + train_scores_std[-1]) - (test_scores_mean[-1] - test_scores_std[-1])
    return midpoint, diff


if __name__ == '__main__':
    X = np.array([[ 1.        , -0.12493874,  0.04575749],
                  [ 0.        , -0.30103   ,  0.03140846],
                  [ 1.        , -0.17609126,  0.11394335],
                  [ 1.        , -0.30103   , -0.06694679],
                  [ 1.        , -0.30103   , -0.12104369],
                  [ 1.        , -0.23408321,  0.11270428],
                  [ 1.        ,  0.19188553,  0.22577904],
                  [ 1.        , -0.23736092, -0.42100531],
                  [ 0.        ,  0.21085337,  0.13966199],
                  [ 1.        , -0.06214791,  0.07716595],
                  [ 1.        ,  0.14612804, -0.01223446],
                  [ 1.        ,  0.1383027 ,  0.1217336 ],
                  [ 1.        , -0.30103   , -0.18073616],
                  [ 0.        ,  0.02996322, -0.09108047],
                  [ 0.        ,  0.05435766,  0.1638568 ],
                  [ 1.        , -0.11394335,  0.        ],
                  [ 1.        ,  0.06694679,  0.30998484],
                  [ 0.        ,  0.64345268,  0.02802872],
                  [ 1.        ,  0.        , -0.01639042],
                  [ 0.        ,  0.11394335, -0.0234811 ],
                  [ 0.        ,  0.        ,  0.18799048],
                  [ 1.        ,  0.        ,  0.10914447],
                  [ 1.        , -0.04139269,  0.        ],
                  [ 0.        ,  0.18905624,  0.17026172],
                  [ 1.        , -0.14132915,  0.15209098],
                  [ 0.        ,  0.30103   ,  0.27036118],
                  [ 1.        ,  0.22184875,  0.05435766],
                  [ 0.        ,  0.34242268,  0.09455611],
                  [ 1.        , -0.20411998, -0.1173856 ],
                  [ 0.        ,  0.11394335,  0.01189922],
                  [ 1.        , -0.22184875, -0.01378828],
                  [ 1.        ,  0.13262557,  0.14390658],
                  [ 0.        ,  0.14612804,  0.13353891]])
    y = np.array([1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1,
                  0, 1, 0, 0, 0, 1, 0, 1, 0, 0])
    Gmodel = GaussianNB()
    # Raw scores for three training-set sizes with 3-fold cross-validation
    train_sizes, train_scores, test_scores = learning_curve(Gmodel, X, y, train_sizes=[3, 6, 10], cv=3)
    # Plot the full curve with the wrapper's default settings
    plot_learning_curve(Gmodel, "learning curve", X, y)
Results
Note: the data set is small, and some overfitting shows up at a training size of 10.
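With so few samples, each score is computed on only a handful of points and can swing noticeably from fold to fold. One possible way to get a steadier picture, sketched here as an assumption rather than as part of the original experiment, is to repeat the cross-validation several times with RepeatedStratifiedKFold and look at the spread as well as the mean. The data below is synthetic and only mimics the size of the example data set (33 samples, 3 features).

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import RepeatedStratifiedKFold, learning_curve
from sklearn.naive_bayes import GaussianNB

# Synthetic stand-in with the same shape as the example data (an assumption)
X_small, y_small = make_classification(
    n_samples=33, n_features=3, n_informative=2, n_redundant=0, random_state=0)

# 3 folds repeated 10 times gives 30 train/validation splits per training size
repeated_cv = RepeatedStratifiedKFold(n_splits=3, n_repeats=10, random_state=0)

sizes, train_scores, cv_scores = learning_curve(
    GaussianNB(), X_small, y_small, cv=repeated_cv, train_sizes=[3, 6, 10])

# Report mean and standard deviation over all 30 splits
for n, tr, cv_row in zip(sizes, train_scores, cv_scores):
    print(f"n={n:2d}  train={tr.mean():.3f} +/- {tr.std():.3f}  "
          f"cv={cv_row.mean():.3f} +/- {cv_row.std():.3f}")
```

If the standard deviations are of the same order as the gap between the two curves, the apparent overfitting at a given size may just be noise from the small sample.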