【深度学习】实验04 交叉验证

交叉验证

交叉验证是一种评估和选择机器学习模型性能的常用方法。它将数据集划分为训练集和验证集,并重复多次进行模型训练和性能评估,以获取更稳定和可靠的模型评估结果。

# 导入相关库

# 交叉验证所需函数
from sklearn.model_selection import train_test_split,cross_val_score,cross_validate
# 交叉验证所需子集划分方法
from sklearn.model_selection import KFold,LeaveOneOut,LeavePOut,ShuffleSplit
# 分层分割
from sklearn.model_selection import StratifiedKFold,StratifiedShuffleSplit
# 分组分割
from sklearn.model_selection import GroupKFold,LeaveOneGroupOut,LeavePGroupsOut,GroupShuffleSplit
# 时间序列分割
from sklearn.model_selection import TimeSeriesSplit
# 自带数据集
from sklearn import datasets
# SVM算法
from sklearn import svm
# 预处理模块
from sklearn import preprocessing
# 模型度量
from sklearn.metrics import recall_score

划分

# 加载数据集
iris = datasets.load_iris()
print('样本集大小:', iris.data.shape, iris.target.shape)
print('样本:', iris.data, iris.target)
   样本集大小: (150, 4) (150,)
   样本: [[5.1 3.5 1.4 0.2]
    [4.9 3.  1.4 0.2]
    [4.7 3.2 1.3 0.2]
    [4.6 3.1 1.5 0.2]
    [5.  3.6 1.4 0.2]
    [5.4 3.9 1.7 0.4]
    [4.6 3.4 1.4 0.3]
    [5.  3.4 1.5 0.2]
    [4.4 2.9 1.4 0.2]
    [4.9 3.1 1.5 0.1]
    [5.4 3.7 1.5 0.2]
    [4.8 3.4 1.6 0.2]
    [4.8 3.  1.4 0.1]
    [4.3 3.  1.1 0.1]
    [5.8 4.  1.2 0.2]
    [5.7 4.4 1.5 0.4]
    [5.4 3.9 1.3 0.4]
    [5.1 3.5 1.4 0.3]
    [5.7 3.8 1.7 0.3]
    [5.1 3.8 1.5 0.3]
    [5.4 3.4 1.7 0.2]
    [5.1 3.7 1.5 0.4]
    [4.6 3.6 1.  0.2]
    [5.1 3.3 1.7 0.5]
    [4.8 3.4 1.9 0.2]
    [5.  3.  1.6 0.2]
    [5.  3.4 1.6 0.4]
    [5.2 3.5 1.5 0.2]
    [5.2 3.4 1.4 0.2]
    [4.7 3.2 1.6 0.2]
    [4.8 3.1 1.6 0.2]
    [5.4 3.4 1.5 0.4]
    [5.2 4.1 1.5 0.1]
    [5.5 4.2 1.4 0.2]
    [4.9 3.1 1.5 0.1]
    [5.  3.2 1.2 0.2]
    [5.5 3.5 1.3 0.2]
    [4.9 3.1 1.5 0.1]
    [4.4 3.  1.3 0.2]
    [5.1 3.4 1.5 0.2]
    [5.  3.5 1.3 0.3]
    [4.5 2.3 1.3 0.3]
    [4.4 3.2 1.3 0.2]
    [5.  3.5 1.6 0.6]
    [5.1 3.8 1.9 0.4]
    [4.8 3.  1.4 0.3]
    [5.1 3.8 1.6 0.2]
    [4.6 3.2 1.4 0.2]
    [5.3 3.7 1.5 0.2]
    [5.  3.3 1.4 0.2]
    [7.  3.2 4.7 1.4]
    [6.4 3.2 4.5 1.5]
    [6.9 3.1 4.9 1.5]
    [5.5 2.3 4.  1.3]
    [6.5 2.8 4.6 1.5]
    [5.7 2.8 4.5 1.3]
    [6.3 3.3 4.7 1.6]
    [4.9 2.4 3.3 1. ]
    [6.6 2.9 4.6 1.3]
    [5.2 2.7 3.9 1.4]
    [5.  2.  3.5 1. ]
    [5.9 3.  4.2 1.5]
    [6.  2.2 4.  1. ]
    [6.1 2.9 4.7 1.4]
    [5.6 2.9 3.6 1.3]
    [6.7 3.1 4.4 1.4]
    [5.6 3.  4.5 1.5]
    [5.8 2.7 4.1 1. ]
    [6.2 2.2 4.5 1.5]
    [5.6 2.5 3.9 1.1]
    [5.9 3.2 4.8 1.8]
    [6.1 2.8 4.  1.3]
    [6.3 2.5 4.9 1.5]
    [6.1 2.8 4.7 1.2]
    [6.4 2.9 4.3 1.3]
    [6.6 3.  4.4 1.4]
    [6.8 2.8 4.8 1.4]
    [6.7 3.  5.  1.7]
    [6.  2.9 4.5 1.5]
    [5.7 2.6 3.5 1. ]
    [5.5 2.4 3.8 1.1]
    [5.5 2.4 3.7 1. ]
    [5.8 2.7 3.9 1.2]
    [6.  2.7 5.1 1.6]
    [5.4 3.  4.5 1.5]
    [6.  3.4 4.5 1.6]
    [6.7 3.1 4.7 1.5]
    [6.3 2.3 4.4 1.3]
    [5.6 3.  4.1 1.3]
    [5.5 2.5 4.  1.3]
    [5.5 2.6 4.4 1.2]
    [6.1 3.  4.6 1.4]
    [5.8 2.6 4.  1.2]
    [5.  2.3 3.3 1. ]
    [5.6 2.7 4.2 1.3]
    [5.7 3.  4.2 1.2]
    [5.7 2.9 4.2 1.3]
    [6.2 2.9 4.3 1.3]
    [5.1 2.5 3.  1.1]
    [5.7 2.8 4.1 1.3]
    [6.3 3.3 6.  2.5]
    [5.8 2.7 5.1 1.9]
    [7.1 3.  5.9 2.1]
    [6.3 2.9 5.6 1.8]
    [6.5 3.  5.8 2.2]
    [7.6 3.  6.6 2.1]
    [4.9 2.5 4.5 1.7]
    [7.3 2.9 6.3 1.8]
    [6.7 2.5 5.8 1.8]
    [7.2 3.6 6.1 2.5]
    [6.5 3.2 5.1 2. ]
    [6.4 2.7 5.3 1.9]
    [6.8 3.  5.5 2.1]
    [5.7 2.5 5.  2. ]
    [5.8 2.8 5.1 2.4]
    [6.4 3.2 5.3 2.3]
    [6.5 3.  5.5 1.8]
    [7.7 3.8 6.7 2.2]
    [7.7 2.6 6.9 2.3]
    [6.  2.2 5.  1.5]
    [6.9 3.2 5.7 2.3]
    [5.6 2.8 4.9 2. ]
    [7.7 2.8 6.7 2. ]
    [6.3 2.7 4.9 1.8]
    [6.7 3.3 5.7 2.1]
    [7.2 3.2 6.  1.8]
    [6.2 2.8 4.8 1.8]
    [6.1 3.  4.9 1.8]
    [6.4 2.8 5.6 2.1]
    [7.2 3.  5.8 1.6]
    [7.4 2.8 6.1 1.9]
    [7.9 3.8 6.4 2. ]
    [6.4 2.8 5.6 2.2]
    [6.3 2.8 5.1 1.5]
    [6.1 2.6 5.6 1.4]
    [7.7 3.  6.1 2.3]
    [6.3 3.4 5.6 2.4]
    [6.4 3.1 5.5 1.8]
    [6.  3.  4.8 1.8]
    [6.9 3.1 5.4 2.1]
    [6.7 3.1 5.6 2.4]
    [6.9 3.1 5.1 2.3]
    [5.8 2.7 5.1 1.9]
    [6.8 3.2 5.9 2.3]
    [6.7 3.3 5.7 2.5]
    [6.7 3.  5.2 2.3]
    [6.3 2.5 5.  1.9]
    [6.5 3.  5.2 2. ]
    [6.2 3.4 5.4 2.3]
    [5.9 3.  5.1 1.8]] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
    1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
    2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
    2 2]

1.自定义划分

# 数据集划分
# 交叉验证划分训练集和测试集.test_size为测试集所占的比例
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size = 0.4, random_state = 0)
print('训练集:', X_train, y_train)
print('测试集:', X_test, y_test)
   训练集: [[6.  3.4 4.5 1.6]
    [4.8 3.1 1.6 0.2]
    [5.8 2.7 5.1 1.9]
    [5.6 2.7 4.2 1.3]
    [5.6 2.9 3.6 1.3]
    [5.5 2.5 4.  1.3]
    [6.1 3.  4.6 1.4]
    [7.2 3.2 6.  1.8]
    [5.3 3.7 1.5 0.2]
    [4.3 3.  1.1 0.1]
    [6.4 2.7 5.3 1.9]
    [5.7 3.  4.2 1.2]
    [5.4 3.4 1.7 0.2]
    [5.7 4.4 1.5 0.4]
    [6.9 3.1 4.9 1.5]
    [4.6 3.1 1.5 0.2]
    [5.9 3.  5.1 1.8]
    [5.1 2.5 3.  1.1]
    [4.6 3.4 1.4 0.3]
    [6.2 2.2 4.5 1.5]
    [7.2 3.6 6.1 2.5]
    [5.7 2.9 4.2 1.3]
    [4.8 3.  1.4 0.1]
    [7.1 3.  5.9 2.1]
    [6.9 3.2 5.7 2.3]
    [6.5 3.  5.8 2.2]
    [6.4 2.8 5.6 2.1]
    [5.1 3.8 1.6 0.2]
    [4.8 3.4 1.6 0.2]
    [6.5 3.2 5.1 2. ]
    [6.7 3.3 5.7 2.1]
    [4.5 2.3 1.3 0.3]
    [6.2 3.4 5.4 2.3]
    [4.9 3.  1.4 0.2]
    [5.7 2.5 5.  2. ]
    [6.9 3.1 5.4 2.1]
    [4.4 3.2 1.3 0.2]
    [5.  3.6 1.4 0.2]
    [7.2 3.  5.8 1.6]
    [5.1 3.5 1.4 0.3]
    [4.4 3.  1.3 0.2]
    [5.4 3.9 1.7 0.4]
    [5.5 2.3 4.  1.3]
    [6.8 3.2 5.9 2.3]
    [7.6 3.  6.6 2.1]
    [5.1 3.5 1.4 0.2]
    [4.9 3.1 1.5 0.1]
    [5.2 3.4 1.4 0.2]
    [5.7 2.8 4.5 1.3]
    [6.6 3.  4.4 1.4]
    [5.  3.2 1.2 0.2]
    [5.1 3.3 1.7 0.5]
    [6.4 2.9 4.3 1.3]
    [5.4 3.4 1.5 0.4]
    [7.7 2.6 6.9 2.3]
    [4.9 2.4 3.3 1. ]
    [7.9 3.8 6.4 2. ]
    [6.7 3.1 4.4 1.4]
    [5.2 4.1 1.5 0.1]
    [6.  3.  4.8 1.8]
    [5.8 4.  1.2 0.2]
    [7.7 2.8 6.7 2. ]
    [5.1 3.8 1.5 0.3]
    [4.7 3.2 1.6 0.2]
    [7.4 2.8 6.1 1.9]
    [5.  3.3 1.4 0.2]
    [6.3 3.4 5.6 2.4]
    [5.7 2.8 4.1 1.3]
    [5.8 2.7 3.9 1.2]
    [5.7 2.6 3.5 1. ]
    [6.4 3.2 5.3 2.3]
    [6.7 3.  5.2 2.3]
    [6.3 2.5 4.9 1.5]
    [6.7 3.  5.  1.7]
    [5.  3.  1.6 0.2]
    [5.5 2.4 3.7 1. ]
    [6.7 3.1 5.6 2.4]
    [5.8 2.7 5.1 1.9]
    [5.1 3.4 1.5 0.2]
    [6.6 2.9 4.6 1.3]
    [5.6 3.  4.1 1.3]
    [5.9 3.2 4.8 1.8]
    [6.3 2.3 4.4 1.3]
    [5.5 3.5 1.3 0.2]
    [5.1 3.7 1.5 0.4]
    [4.9 3.1 1.5 0.1]
    [6.3 2.9 5.6 1.8]
    [5.8 2.7 4.1 1. ]
    [7.7 3.8 6.7 2.2]
    [4.6 3.2 1.4 0.2]] [1 0 2 1 1 1 1 2 0 0 2 1 0 0 1 0 2 1 0 1 2 1 0 2 2 2 2 0 0 2 2 0 2 0 2 2 0
    0 2 0 0 0 1 2 2 0 0 0 1 1 0 0 1 0 2 1 2 1 0 2 0 2 0 0 2 0 2 1 1 1 2 2 1 1
    0 1 2 2 0 1 1 1 1 0 0 0 2 1 2 0]
   测试集: [[5.8 2.8 5.1 2.4]
    [6.  2.2 4.  1. ]
    [5.5 4.2 1.4 0.2]
    [7.3 2.9 6.3 1.8]
    [5.  3.4 1.5 0.2]
    [6.3 3.3 6.  2.5]
    [5.  3.5 1.3 0.3]
    [6.7 3.1 4.7 1.5]
    [6.8 2.8 4.8 1.4]
    [6.1 2.8 4.  1.3]
    [6.1 2.6 5.6 1.4]
    [6.4 3.2 4.5 1.5]
    [6.1 2.8 4.7 1.2]
    [6.5 2.8 4.6 1.5]
    [6.1 2.9 4.7 1.4]
    [4.9 3.1 1.5 0.1]
    [6.  2.9 4.5 1.5]
    [5.5 2.6 4.4 1.2]
    [4.8 3.  1.4 0.3]
    [5.4 3.9 1.3 0.4]
    [5.6 2.8 4.9 2. ]
    [5.6 3.  4.5 1.5]
    [4.8 3.4 1.9 0.2]
    [4.4 2.9 1.4 0.2]
    [6.2 2.8 4.8 1.8]
    [4.6 3.6 1.  0.2]
    [5.1 3.8 1.9 0.4]
    [6.2 2.9 4.3 1.3]
    [5.  2.3 3.3 1. ]
    [5.  3.4 1.6 0.4]
    [6.4 3.1 5.5 1.8]
    [5.4 3.  4.5 1.5]
    [5.2 3.5 1.5 0.2]
    [6.1 3.  4.9 1.8]
    [6.4 2.8 5.6 2.2]
    [5.2 2.7 3.9 1.4]
    [5.7 3.8 1.7 0.3]
    [6.  2.7 5.1 1.6]
    [5.9 3.  4.2 1.5]
    [5.8 2.6 4.  1.2]
    [6.8 3.  5.5 2.1]
    [4.7 3.2 1.3 0.2]
    [6.9 3.1 5.1 2.3]
    [5.  3.5 1.6 0.6]
    [5.4 3.7 1.5 0.2]
    [5.  2.  3.5 1. ]
    [6.5 3.  5.5 1.8]
    [6.7 3.3 5.7 2.5]
    [6.  2.2 5.  1.5]
    [6.7 2.5 5.8 1.8]
    [5.6 2.5 3.9 1.1]
    [7.7 3.  6.1 2.3]
    [6.3 3.3 4.7 1.6]
    [5.5 2.4 3.8 1.1]
    [6.3 2.7 4.9 1.8]
    [6.3 2.8 5.1 1.5]
    [4.9 2.5 4.5 1.7]
    [6.3 2.5 5.  1.9]
    [7.  3.2 4.7 1.4]
    [6.5 3.  5.2 2. ]] [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
    1 1 1 2 0 2 0 0 1 2 2 2 2 1 2 1 1 2 2 2 2 1 2]
# 训练模型
clf = svm.SVC(kernel = 'linear', C = 1).fit(X_train, y_train)
# 计算准确率
print('准确率:', clf.score(X_test, y_test))
准确率: 0.9666666666666667
# 如果涉及到归一化,则在测试集上也要使用训练集模型提取的归一化函数。
# 通过训练集获得归一化函数模型。(也就是先减几,再除以几的函数)。在训练集和测试集上都使用这个归一化函数
scaler = preprocessing.StandardScaler()
X_train_transformed = scaler.fit_transform(X_train)
clf = svm.SVC(kernel = 'linear', C = 1).fit(X_train_transformed, y_train)
X_test_transformed = scaler.fit_transform(X_test)
print('准确率:', clf.score(X_test_transformed, y_test))
准确率: 0.9333333333333333
# 直接调用交叉验证评估模型
clf = svm.SVC(kernel = 'linear', C = 1)
scores = cross_val_score(clf, iris.data, iris.target, cv = 5)
# 打印输出每次迭代的度量值(准确度)
print(scores)
# 获取置信区间。(也就是均值和方差)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))
[0.96666667 1.         0.96666667 0.96666667 1.        ]
Accuracy: 0.98 (+/- 0.03)
# 多种度量结果
# precision_macro为精度,recall_macro为召回率
scoring = ['precision_macro', 'recall_macro']
scores = cross_validate(clf, iris.data, iris.target, scoring = scoring, cv = 5, return_train_score = True)
# scores类型为字典。包含训练得分,拟合次数, score-times (得分次数)
sorted(scores.keys())
print('测试结果:', scores)
测试结果: {'fit_time': array([0.00113702, 0.00095534, 0.0007391 , 0.00055671, 0.0003612 ]), 'score_time': array([0.00205898, 0.00153756, 0.00125694, 0.00080943, 0.00079727]), 'test_precision_macro': array([0.96969697, 1.        , 0.96969697, 0.96969697, 1.        ]), 'train_precision_macro': array([0.97674419, 0.97674419, 0.99186992, 0.98412698, 0.98333333]), 'test_recall_macro': array([0.96666667, 1.        , 0.96666667, 0.96666667, 1.        ]), 'train_recall_macro': array([0.975     , 0.975     , 0.99166667, 0.98333333, 0.98333333])}

2.K折交叉验证

K-fold交叉验证是一种经典的模型选择方法,它主要用于评估机器学习模型的性能并选择最佳超参数值。它的基本思想是将训练数据集分成K个子集,然后进行K次训练和验证,每次将其中一个子集作为验证集,其他K-1个子集作为训练集。最终,将K次验证结果的平均值作为模型的性能指标。

K-fold交叉验证的优点是能够更准确地评估模型的性能,因为每个样本都会被用于训练和验证。同时,它还能够更有效地利用有限的数据集,因为它能够充分利用所有的数据进行模型选择。但是,它的计算成本较高,特别是在数据集较大时,训练和验证的时间会很长。

# K折交叉验证
kf = KFold(n_splits = 2)
for train, test in kf.split(iris.data):
    print("k折划分:%s %s" % (train.shape, test.shape))
    break
k折划分:(75,) (75,)

3.留一交叉验证

留一交叉验证(Leave-one-out cross-validation, LOOCV)是一种特殊的K-fold交叉验证方法,其中K等于训练数据集的大小。在每次训练中,留一交叉验证将一个样本作为验证集,将其余的样本作为训练集,然后用模型对验证集进行预测,最终用所有的预测结果计算模型的性能指标。

留一交叉验证的优点是能够非常准确地评估模型的性能,因为每个样本都会被用于验证。但是,它的计算成本非常高,特别是在数据集非常大的情况下,需要进行大量的训练和验证操作。因此,留一交叉验证通常只在数据集非常小的情况下使用,而在一般情况下,K-fold交叉验证通常是更好的选择。

#留一交叉验证
loo = LeaveOneOut()
for train, test in loo.split(iris.data):
    print("留一划分:%s %s" % (train.shape, test.shape))
    break
留一划分:(149,) (1,)

4.留p交叉验证

留p交叉验证(Leave-p-out Cross-validation,LPOCV)是一种在K-fold交叉验证方法的基础上进行改进的方法。在留p交叉验证中,每次从训练数据集中留出p个样本作为验证集,然后将剩余的样本作为训练集进行模型训练。这个步骤重复进行p次,直到所有的样本都被用于验证过一次。

与留一交叉验证相比,留p交叉验证的计算成本更低,并且可以提供比K-fold交叉验证更准确的性能评估。但是,留p交叉验证仍然需要对所有可能的组合进行训练和验证,因此在数据集较大时计算成本仍然很高,通常只在数据集较小的情况下使用。

留p交叉验证通常用于在数据集中存在特定的结构或相关性时,以确保训练集和验证集的样本能够充分表示这种结构或相关性。

# 留p交叉验证
lpo = LeavePOut(p=2)
for train, test in loo.split(iris.data):
    print("留p划分:%s %s" % (train.shape, test.shape))
    break
留p划分:(149,) (1,)

5.随机排列交叉验证

随机排列交叉验证(Random Permutation Cross Validation)是一种基于数据随机排列的交叉验证方法。它的基本思想是将原始数据集随机打乱后,再按照一定比例划分为训练集和测试集,重复这个过程多次,每次划分的训练集和测试集都是不同的,这样可以更准确地评估模型的性能。

具体来说,随机排列交叉验证的步骤如下:

  1. 将原始数据集随机打乱。

  2. 将打乱后的数据集按照一定比例(如70%训练集,30%测试集)划分为训练集和测试集。

  3. 使用训练集训练模型,并计算在测试集上的性能指标(如准确率、召回率等)。

  4. 重复上述过程多次,每次随机打乱数据集并重新划分训练集和测试集。

  5. 将多次测试的性能指标取平均值,作为模型的最终性能评估指标。

随机排列交叉验证可以有效避免数据集中某些特定的排序顺序对模型性能评估造成的影响,同时可以充分利用数据集中的所有样本进行模型训练和测试,提高模型的泛化能力。

# 随机排列交叉验证
ss = ShuffleSplit(n_splits=3, test_size=0.25,random_state=0)
for train_index, test_index in ss.split(iris.data):
    print("随机排列划分:%s %s" % (train.shape, test.shape))
    break
随机排列划分:(149,) (1,)

6.分层K折交叉验证

分层K折交叉验证(Stratified K-Fold Cross Validation)是一种K折交叉验证的变体,它考虑了样本的分布情况。在数据集中,可能存在不同类别的样本数量不均衡的情况,为了保证每个类别的样本在训练集和测试集中的比例相同,可以使用分层K折交叉验证。

分层K折交叉验证的操作步骤如下:

  1. 将数据集按照类别进行划分。

  2. 对于每个类别,将其样本数除以K得到一个整数和一个余数;将整数部分的样本均分成K份,每份样本数相同,余数加到前面的几份中,这样每份的样本数就是整数部分加1的样本数。

  3. 对于每个类别,将其样本按照第2步划分后的份数进行编号。

  4. 对于每一折i(i=1,2,…,K),从每个类别中选择编号为i、i+K、i+2K…、i+(n-1)K的样本,组成这一折的测试集;其余样本组成训练集。

  5. 重复步骤4直到每折都作为测试集,得到K组不同的训练集和测试集。

分层K折交叉验证保证了训练集和测试集中,每个类别的样本比例都相同,能够更准确地评估模型的性能。

# 分层K折交叉验证
skf = StratifiedKFold(n_splits=3)  #各个类别的比例大致和完整数据集中相同
for train, test in skf.split(iris.data, iris.target):
    print("分层K折划分:%s %s" % (train.shape, test.shape))
    break
分层K折划分:(99,) (51,)

7.分层随机交叉验证

分层随机交叉验证(Stratified Shuffle Split Cross-Validation)是一种交叉验证方法,它将数据集分为k个不重叠的折叠,每次用其中一折作为测试集,剩余的k-1折作为训练集。不同于普通随机交叉验证,分层随机交叉验证会保证每一折中各类别样本的比例相同。

具体的步骤如下:

  1. 将数据集中的样本按照类别进行划分;
  2. 对于每个类别,从该类别的样本中按照一定比例随机抽取样本,组成测试集;
  3. 将不在测试集中的其余样本组成训练集。

分层随机交叉验证可以应用于处理分类问题,以评估模型的性能并调整模型的超参数。比如在使用支持向量机(SVM)训练分类器时,可以利用分层随机交叉验证来选择最佳的C和gamma值来优化模型。

# 分层随机交叉验证
skf = StratifiedShuffleSplit(n_splits=3)  # 划分中每个类的比例和完整数据集中的相同
for train, test in skf.split(iris.data, iris.target):
    print("分层随机划分:%s %s" % (train.shape, test.shape))
    break
分层随机划分:(135,) (15,)

分割

X = [0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 8.8, 9, 10]
y = ["a", "b", "b", "b", "c", "c", "c", "d", "d", "d"]
groups = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]

1.组 k-fold分割

# k折分组
gkf = GroupKFold(n_splits=3)  # 训练集和测试集属于不同的组
for train, test in gkf.split(X, y, groups=groups):
    print("组 k-fold分割:%s %s" % (train, test))
组 k-fold分割:[0 1 2 3 4 5] [6 7 8 9]
组 k-fold分割:[0 1 2 6 7 8 9] [3 4 5]
组 k-fold分割:[3 4 5 6 7 8 9] [0 1 2]

2.留一组分割

# 留一分组
logo = LeaveOneGroupOut()
for train, test in logo.split(X, y, groups=groups):
    print("留一组分割:%s %s" % (train, test))
留一组分割:[3 4 5 6 7 8 9] [0 1 2]
留一组分割:[0 1 2 6 7 8 9] [3 4 5]
留一组分割:[0 1 2 3 4 5] [6 7 8 9]

3.留 P 组分割

# 留p分组
lpgo = LeavePGroupsOut(n_groups=2)
for train, test in lpgo.split(X, y, groups=groups):
    print("留 P 组分割:%s %s" % (train, test))
留 P 组分割:[6 7 8 9] [0 1 2 3 4 5]
留 P 组分割:[3 4 5] [0 1 2 6 7 8 9]
留 P 组分割:[0 1 2] [3 4 5 6 7 8 9]

4.随机分割

# 随机分组
gss = GroupShuffleSplit(n_splits=4, test_size=0.5, random_state=0)
for train, test in gss.split(X, y, groups=groups):
    print("随机分割:%s %s" % (train, test))

随机分割:[0 1 2] [3 4 5 6 7 8 9]
随机分割:[3 4 5] [0 1 2 6 7 8 9]
随机分割:[3 4 5] [0 1 2 6 7 8 9]
随机分割:[3 4 5] [0 1 2 6 7 8 9]

5.时间序列分割

# 时间序列分割
tscv = TimeSeriesSplit(n_splits=3)
TimeSeriesSplit(max_train_size=None, n_splits=3)
for train, test in tscv.split(iris.data):
    print("时间序列分割:%s %s" % (train, test))
时间序列分割:[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38] [39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
 63 64 65 66 67 68 69 70 71 72 73 74 75]
时间序列分割:[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75] [ 76  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91  92  93
  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108 109 110 111
 112]
时间序列分割:[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112] [113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
 149]

附:系列文章

序号 文章目录 直达链接
1 波士顿房价预测 https://want595.blog.csdn.net/article/details/132181950
2 鸢尾花数据集分析 https://want595.blog.csdn.net/article/details/132182057
3 特征处理 https://want595.blog.csdn.net/article/details/132182165
4 交叉验证 https://want595.blog.csdn.net/article/details/132182238
5 构造神经网络示例 https://want595.blog.csdn.net/article/details/132182341
6 使用TensorFlow完成线性回归 https://want595.blog.csdn.net/article/details/132182417
7 使用TensorFlow完成逻辑回归 https://want595.blog.csdn.net/article/details/132182496
8 TensorBoard案例 https://want595.blog.csdn.net/article/details/132182584
9 使用Keras完成线性回归 https://want595.blog.csdn.net/article/details/132182723
10 使用Keras完成逻辑回归 https://want595.blog.csdn.net/article/details/132182795
11 使用Keras预训练模型完成猫狗识别 https://want595.blog.csdn.net/article/details/132243928
12 使用PyTorch训练模型 https://want595.blog.csdn.net/article/details/132243989
13 使用Dropout抑制过拟合 https://want595.blog.csdn.net/article/details/132244111
14 使用CNN完成MNIST手写体识别(TensorFlow) https://want595.blog.csdn.net/article/details/132244499
15 使用CNN完成MNIST手写体识别(Keras) https://want595.blog.csdn.net/article/details/132244552
16 使用CNN完成MNIST手写体识别(PyTorch) https://want595.blog.csdn.net/article/details/132244641
17 使用GAN生成手写数字样本 https://want595.blog.csdn.net/article/details/132244764
18 自然语言处理 https://want595.blog.csdn.net/article/details/132276591

猜你喜欢

转载自blog.csdn.net/m0_68111267/article/details/132182238
今日推荐