Python机器学习入门: sklearn.learning_curve 训练结果可视化实例(完整代码)

介绍

Scikit-learn提供了learning_curve类,方便获得和训练的可视化相关的数据。例如,如果想要观察训练集使用不同样本数量和训练得分/测试得分的关系,可以使用learning_curve函数可视化,得到训练样本数量——训练/测试得分曲线如下。

本文将具体介绍实现过程。

from sklearn.learning_curve import learning_curve, validation_curve

1.数据集选用load_digit数据集 简单介绍

from sklearn.datasets import load_digits
X = digits['data'] #(1797,64)
Y = digits['target'] #(1797,)

2.采用KNC(KNeighborsClassifier)分类器

from sklearn.neighbors import KNeighborsClassifier
knc = KNeighborsClassifier(n_neighbors=3, algorithm='auto', weights='distance', n_jobs=1)

调用learning_curve函数

train_size, train_scores, test_scores = learning_curve(knc, X, Y, cv=10, scoring='accuracy',#10折交叉验证
                                                       train_sizes = np.linspace(0.1,1.0,5))#5次的训练数量占比

解释:
1.输入
knc:分类器
X: data
Y: 标签
cv: K折交叉验证,本文选择10
scoring: sklearn一般评估方法
机器学习中的precision, recall, accuracy, F值
train_sizes:训练示例的相对或绝对数量 ,例如np.linspace(0.1,1.0,5))表示5次的训练数量占总数据集样本数量的占比
2.输出:
train_size: 根据train_sizes得到训练集样本的数量 ,本文为(5,)
train_scores:对于5种不同数量的训练集,对10折交叉验证的10个训练得分,本文为(5,10)
test_scores:对于5种不同数量的训练集,对10折交叉验证的10个测试得分,本文为(5,10)

函数说明具体可以参考: sklearn中的学习曲线learning_curve函数

处理

1.对于5种不同样本数量的训练集,对10折交叉验证的10个训练/测试得分取平均值(即压缩列)。
2.得到得分范围的上下界

mean_train = np.mean(train_scores,1)  #(5,)
# 得到得分范围的上下界
upper_train = np.clip(mean_train + np.std(train_scores,1),0,1) 
lower_train = np.clip(mean_train - np.std(train_scores,1),0,1)
    
mean_test = np.mean(test_scores,1)
# 得到得分范围的上下界
upper_test = np.clip(mean_test + np.std(test_scores,1),0,1) 
lower_test = np.clip(mean_test - np.std(test_scores,1),0,1)

画图

然后就可以得到训练样本数量——训练/测试得分曲线了。

plt.figure('Fig1')
plt.plot(train_size,mean_train,'ro-',label='train')
plt.plot(train_size,mean_test,'go-',label='test')
##填充上下界的范围
plt.fill_between(train_size,upper_train,lower_train,alpha=0.2,#alpha:覆盖区域的透明度[0,1],其值越大,表示越不透明 
         color='r')                   
plt.fill_between(train_size,upper_test,lower_test,alpha=0.2,#alpha:覆盖区域的透明度[0,1],其值越大,表示越不透明 
         color='g')  
plt.grid()
plt.xlabel('train size')
plt.ylabel('accuracy')
plt.legend(loc='lower right')
plt.title('KNC')
plt.savefig('train number-size.png')
plt.show()

猜你喜欢

转载自blog.csdn.net/qq_36937684/article/details/105948980