python机器学习交叉验证实例

  • 交叉验证(CrossValidation)是常用的机器学习训练手段,可以有效检验一个模型的泛化能力。交叉验证需要将原始数据集平等地划分为若干份,例如常用的10折交叉验证,10-folds CV 指的是将数据集分为10份,然后进行10次训练,每次取出一份数据作为测试集,剩下的作为训练集,得到10个模型,最终将10个模型的预测值做一个平均。

具体python代码如下:

def plot_cross_val(rf4, train_x, train_y,cv_num,path_out):
    import matplotlib.pyplot as plt
    from sklearn.model_selection import cross_val_score
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    evaluate_vars = ['roc_auc','precision','recall','f1']
    fig = plt.figure()
    for plot_num in range(len(evaluate_vars)):
        ax1 = fig.add_subplot(2, 2, plot_num + 1)
        plt.subplots_adjust(left=None, bottom=None, right=None, top=None,
                            wspace=0.3, hspace=0.5)
        try:
            scores = cross_val_score(rf4, train_x, train_y, cv=cv_num, scoring=evaluate_vars[plot_num])
        except ValueError:
            scores = np.zeros(10)
        plt.plot(range(10), scores)
        plt.xlabel('num of cv')
        plt.ylabel(evaluate_vars[plot_num])
        plt.xticks(np.arange(0, 10, 1),fontsize=6)
        plt.yticks(np.arange(0, 1.1, 0.2),fontsize=8)
        plt.show()
        tt = 'plot of ' + str(evaluate_vars[plot_num])
        ax1.set_title(tt,fontsize=10)
    plt.savefig(path_out, bbox_inches='tight', dpi=300)  # bbox_inches='tight'帮助删除图片空白部分
    plt.show()

if __name__ == '__main__':
    path_out = 'E:/program'
    plot_cross_val(rf4, train_x, train_y,10,path_out)

效果如下:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_41233157/article/details/107935558