机器学习:交叉验证和模型选择与Python代码实现

前言:本篇博文主要介绍交叉验证(cross validation)和模型选择,首先介绍相关的基础概念和原理,然后通过Python代码实现交叉验证和模型评估以及选择。特别强调,其中大多理论知识来源于《统计学习方法_李航》斯坦福课程翻译笔记

1.分类器的评价

评价分类器性能的指标一般是分类的准确率,其定义是:对于给定的测试数据集(X_test),分类器正确分类的样本数与总样本数之比。

这里特别强调一下,有一种分类问题——“偏斜分类”——一般是“二分类”问题。这类问题有一个特点就是:某一类的样本数很少,例如病患者中的癌症患者(假如只占了0.5%),那么我现在不需要什么高级算法,只需要对任何人说,“你没癌症”,那么我的正确率理论上都达到了99.5%,所以他不适合简单的计算准确率,而是所谓的“精确率”(precision)和“召回率”(recall)的调和均值:F1 score。(这里不介绍,详情请看前言部分提供的电子书)。

2.分类器的选择

前面的knn算法中提到过,knn分类器里面的K值到底设置多少才是合适的?还有knn算法中,距离函数到底该选择L1范数还是L2范数?当然,可能还有其他需要考虑的选择我们没考虑到。所有这些参数的选择,我们称之为“超参数”(hyperparameter)。

我们一般建议,尝试不同的值,也就是尝试不同的算法,看看哪个算法的性能最好?但是这里有一点需要特别注意!!!!——我们决不能使用测试集(X_test)来进行参数和模型调优。为什么这么说呢,因为如果你用测试集调优,就相当于你用测试集评价挑选出来的最优算法,然后再用测试集来评价,那效果肯定还是最好。但是一旦遇到实际的数据,往往效果就不理想,这称之为“过拟合”(随后聊)。

总之,记住,“测试集”(X_test)只能在最后一次模型评价中使用!

那么,就有了下面的“交叉验证”!

3.交叉验证

一般,如果给定的样本数据充足,我们会随机的从“训练集”(X_train)中提取一部分数据来调优,这个数据集叫做“验证集”(X_validation)。

但是,如果数据不充足,我们往往采用“交叉验证”的方法——把给定的数据进行切分(例如分成5段),将切分的数据集分为“训练集”和“验证集”(假设其中4份为train,1份为validation),在此基础上循环选取(4份为train,1份为validation),进行训练和验证。从而选择其中最好的模型。

ps:实际运用中,因为交叉验证会耗费较多的计算资源(毕竟一直循环选取,然后训练,验证)。所以一般数据充足,直接随机选取一小部分作为“验证集”就行了。(tips:一定要随机,因为如果原来的数据集排列有一定的规律,例如前100全是A类,后100全是B类,那么你只是简单的用语句X_val = X[150:],得到的就只有B类了。当然,如果样本本来就是随机分布的,那就没关系。)

4.交叉验证与模型选择Python代码实现

不多说,上代码,写的累啊~

首先看看数据集,这是一种花的数据集

import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

#load datasets
iris = load_iris()
data = iris.data[:,:2]
target = iris.target
print data.shape#(150,2)
print data[:10]
print target[:10]

label = np.array(target)
index_0 = np.where(label==0)
plt.scatter(data[index_0,0],data[index_0,1],marker='x',color = 'b',label = '0',s = 15)
index_1 =np.where(label==1)
plt.scatter(data[index_1,0],data[index_1,1],marker='o',color = 'r',label = '1',s = 15)
index_2 =np.where(label==2)
plt.scatter(data[index_2,0],data[index_2,1],marker='s',color = 'g',label = '2',s = 15)
plt.xlabel('X1')
plt.ylabel('X2')
plt.legend(loc = 'upper left')
plt.show()
上面代码为了方便可视化,只是选取了两种特征,运行如下:


接着随机抽取一部分作为“测试集”(X_test),留着最后模型评估。然后再用交叉验证,验证正确率。(注意,这个数据集就有分布规律,0-50为0类,50-100为1类,100-150为2类,所以用了随机选取打乱了分布。)

#split the train sets and test sets,
import knn
from sklearn.model_selection import train_test_split
X,X_test,y,y_test = train_test_split(data,target,test_size=0.2,random_state=1)
print X.shape,X_test.shape

#cross validation
folds = 4
k_choices = [1,3,5,7,9,13,15,20,25]

X_folds = []
y_folds = []

X_folds = np.vsplit(X,folds)
y_folds = np.hsplit(y,folds)

accuracy_of_k = {}
for k in k_choices:
    accuracy_of_k[k] = []
#split the train sets and validation sets
for i in range(folds):
    classify = knn.KNearestNeighbor()
    X_train =np.vstack(X_folds[:i] + X_folds[i+1:]) 
    X_val = X_folds[i]
    y_train = np.hstack(y_folds[:i] + y_folds[i+1:])
    y_val = y_folds[i]
    print X_train.shape,X_val.shape,y_train.shape,y_val.shape
    classify.train(X_train,y_train)
    for k in k_choices:
        y_val_pred = classify.predict(X_val,k = k)
        accuracy = np.mean(y_val_pred == y_val)
        accuracy_of_k[k].append(accuracy)

for k in sorted(k_choices):
    for accuracy in accuracy_of_k[k]:
        print 'k = %d,accuracy = %f' %(k,accuracy)
运行如下:

最后可视化一下,选择里面平均值最好的K,

#show the plot
import matplotlib.pyplot as plt
#show the accuracy 
for k in k_choices:
    plt.scatter([k]*len(accuracy_of_k[k]), accuracy_of_k[k])
accuracies_mean = np.array([np.mean(v) for k,v in sorted(accuracy_of_k.items())])
accuracies_std = np.array([np.std(v) for k,v in sorted(accuracy_of_k.items())])
plt.errorbar(k_choices, accuracies_mean, yerr=accuracies_std)
plt.title('Cross-validation on k')
plt.xlabel('k')
plt.ylabel('Cross-validation accuracy')
plt.show()

最后再用测试集评估,这就得到了你的模型真正的性能!

#we chose the best one
best_k = 13
classify = knn.KNearestNeighbor()
classify.train(X_train,y_train)
y_test_pred = classify.predict(X_test,k=best_k)
num_correct = np.sum(y_test==y_test_pred)
accuracy_test = np.mean(y_test==y_test_pred)
print 'test accuracy is %d/%d = %f' %(num_correct,X_test.shape[0],accuracy_test)
运行得到:test accuracy is 24/30 = 0.800000

总之:交叉验证可以帮助我们进行超参数的调优和模型选择!









猜你喜欢

转载自blog.csdn.net/huakai16/article/details/78070671