sklearn.model_selection.cross_val_score

sklearn.model_selection.cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv='warn', n_jobs=None, verbose=0, fit_params=None, pre_dispatch='2*n_jobs', error_score='raise-deprecating')

Purpose: Evaluate a score by cross-validation.
Parameters:
estimator : estimator object implementing 'fit'
The object to use to fit the data.
The model to be fitted, for example a linear model such as Lasso (it does not have to be linear).
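The following is a minimal sketch of an estimator that is not a linear model. Any object with fit works, and it also needs a score method because scoring=None here. DecisionTreeRegressor with max_depth=3 is only an illustrative choice, and the data slice mirrors the example at the end of this page. The sketch also shows that cross_val_score fits clones of the estimator, so the object you pass in stays unfitted.

```python
from sklearn import datasets
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

diabetes = datasets.load_diabetes()
X, y = diabetes.data[:150], diabetes.target[:150]

# A non-linear estimator: anything with fit() (and score(), since scoring=None) works
tree = DecisionTreeRegressor(max_depth=3, random_state=0)
tree_scores = cross_val_score(tree, X, y, cv=3)
print(tree_scores)  # one R^2 value per fold

# cross_val_score fits clones, so the object passed in is left unfitted
print(hasattr(tree, "tree_"))
```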
X : array-like
The data to fit. Can be for example a list, or an array.
The training data.
y : array-like, optional, default: None
The target variable to try to predict in the case of supervised learning.
The label data.
groups : array-like, with shape (n_samples,), optional
Group labels for the samples used while splitting the dataset into train/test set.
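The group label of each sample (not of each feature). It is only used together with a group-aware splitter such as GroupKFold, which keeps all samples of a group on the same side of every train/test split.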
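Here is a sketch of how groups is typically used. The grouping is made up, because the diabetes dataset has no group column: every 10 consecutive rows are treated as one group. GroupKFold then puts each group entirely in either the training or the test part of a split.

```python
import numpy as np
from sklearn import datasets, linear_model
from sklearn.model_selection import GroupKFold, cross_val_score

diabetes = datasets.load_diabetes()
X, y = diabetes.data[:150], diabetes.target[:150]

# Hypothetical grouping (not part of the dataset): every block of
# 10 consecutive samples is treated as one source, giving 15 groups.
groups = np.arange(len(X)) // 10

cv = GroupKFold(n_splits=3)
group_scores = cross_val_score(linear_model.Lasso(), X, y, groups=groups, cv=cv)
print(group_scores)

# Sanity check: no group is split between the training and the test side
for train_idx, test_idx in cv.split(X, y, groups):
    assert set(groups[train_idx]).isdisjoint(groups[test_idx])
```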
scoring : string, callable or None, optional, default: None
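A string (see model evaluation documentation) or a scorer callable object / function with signature scorer(estimator, X, y), which should return a single value. If None, the estimator's default scorer (its score method) is used.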
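This sketch shows the two common forms of scoring, reusing the diabetes/Lasso setup from the example at the end of this page. 'neg_mean_squared_error' is a built-in scorer name. neg_mae is a hypothetical hand-written callable with the scorer(estimator, X, y) signature. Error metrics are negated so that a greater score is always better, which is why both sets of scores are at most zero.

```python
import numpy as np
from sklearn import datasets, linear_model
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import cross_val_score

diabetes = datasets.load_diabetes()
X, y = diabetes.data[:150], diabetes.target[:150]
lasso = linear_model.Lasso()

# 1) A predefined scorer name (errors are negated: greater is better)
mse_scores = cross_val_score(lasso, X, y, cv=3, scoring="neg_mean_squared_error")

# 2) A plain callable with the signature scorer(estimator, X, y)
def neg_mae(estimator, X_test, y_test):
    """Hypothetical custom scorer: negative mean absolute error on the test fold."""
    return -mean_absolute_error(y_test, estimator.predict(X_test))

mae_scores = cross_val_score(lasso, X, y, cv=3, scoring=neg_mae)
print(mse_scores, mae_scores)
```

The hand-written neg_mae should reproduce the built-in 'neg_mean_absolute_error' scorer. A custom callable is mainly useful for metrics that have no predefined name.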
cv : int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy. Possible inputs for cv are:
•None, to use the default 3-fold cross validation (this default changes to 5-fold in version 0.22),
•integer, to specify the number of folds in a (Stratified)KFold,
•An object to be used as a cross-validation generator.
•An iterable yielding train, test splits.
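In the simplest case, cv is just the number of cross-validation folds. For an integer or None, StratifiedKFold is used if the estimator is a classifier and y is binary or multiclass; otherwise KFold is used.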
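The sketch below passes cv in each accepted form, using the same diabetes/Lasso setup as the example at the end. Lasso is a regressor, so cv=3 should give the same unshuffled KFold splits as two other forms: the explicit KFold(n_splits=3) object and a pre-computed list of (train, test) index pairs. The shuffled 5-fold splitter uses different folds, so its scores differ.

```python
import numpy as np
from sklearn import datasets, linear_model
from sklearn.model_selection import KFold, cross_val_score

diabetes = datasets.load_diabetes()
X, y = diabetes.data[:150], diabetes.target[:150]
lasso = linear_model.Lasso()

# Integer: for a regressor this means an unshuffled KFold with 3 folds
s_int = cross_val_score(lasso, X, y, cv=3)

# Splitter object: the same folds, spelled out explicitly
s_obj = cross_val_score(lasso, X, y, cv=KFold(n_splits=3))

# Iterable of (train_indices, test_indices) pairs
splits = list(KFold(n_splits=3).split(X))
s_iter = cross_val_score(lasso, X, y, cv=splits)

# Shuffled 5-fold splitter: different folds, hence different scores
s_shuf = cross_val_score(lasso, X, y, cv=KFold(n_splits=5, shuffle=True, random_state=0))
print(s_int, s_obj, s_iter, s_shuf, sep="\n")
```

Passing a splitter object with a fixed random_state is the usual way to get shuffled yet reproducible folds.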
n_jobs : int or None, optional (default=None)
The number of CPUs to use to do the computation. None means 1 unless in a joblib.parallel_backend context; -1 means using all processors.
fit_params : dict, optional
Parameters to pass to the fit method of the estimator.
error_score : 'raise' | 'raise-deprecating' or numeric
Value to assign to the score if an error occurs in estimator fitting. If set to 'raise', the error is raised. If set to 'raise-deprecating', a FutureWarning is printed before the error is raised. If a numeric value is given, FitFailedWarning is raised. This parameter does not affect the refit step, which will always raise the error. Default is 'raise-deprecating' but from version 0.22 it will change to np.nan.
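This sketch contrasts a numeric error_score with 'raise'. It uses a hypothetical toy regressor, PickyMeanRegressor, which is not part of scikit-learn and deliberately fails to fit whenever a training target is negative. On the toy data with three unshuffled folds, only the first fold's training set is free of negative targets, so only that fold gets a real score. Only 'raise' and np.nan are used, since both values are accepted in the version documented here and in later releases.

```python
import warnings

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.exceptions import FitFailedWarning
from sklearn.model_selection import KFold, cross_val_score


class PickyMeanRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical toy regressor: predicts the training mean, but refuses
    to fit when any training target is negative."""

    def fit(self, X, y):
        if np.min(y) < 0:
            raise ValueError("negative targets are not supported")
        self.mean_ = float(np.mean(y))
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


X = np.arange(30, dtype=float).reshape(-1, 1)
y = np.arange(30, dtype=float)
y[:10] -= 100  # only samples 0..9 have negative targets

cv = KFold(n_splits=3)  # test folds: samples 0..9, 10..19, 20..29

# Numeric error_score: failed fits get np.nan and a FitFailedWarning is issued
with warnings.catch_warnings():
    warnings.simplefilter("ignore", FitFailedWarning)
    nan_scores = cross_val_score(PickyMeanRegressor(), X, y, cv=cv, error_score=np.nan)
print(nan_scores)  # finite score for the first fold, nan for the other two

# error_score="raise": the estimator's own exception reaches the caller
try:
    cross_val_score(PickyMeanRegressor(), X, y, cv=cv, error_score="raise")
    raised = False
except ValueError:
    raised = True
print(raised)
```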

Examples:

>>> from sklearn import datasets, linear_model
>>> from sklearn.model_selection import cross_val_score
>>> diabetes = datasets.load_diabetes()
>>> X = diabetes.data[:150]
>>> y = diabetes.target[:150]
>>> lasso = linear_model.Lasso()
>>> print(cross_val_score(lasso, X, y, cv=3))  # the four most commonly used arguments
[0.33150734 0.08022311 0.03531764]
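Here is a sketch of roughly what cross_val_score does in this example. For each of the three unshuffled KFold splits, it fits a fresh clone of the estimator on the training part. It then calls the clone's score method (R² for a regressor) on the held-out part. manual_cv_scores is a hypothetical helper written for illustration; the real function also handles scoring, parallelism and error handling.

```python
import numpy as np
from sklearn import datasets, linear_model
from sklearn.base import clone
from sklearn.model_selection import KFold, cross_val_score

diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()

def manual_cv_scores(estimator, X, y, n_splits=3):
    """Hypothetical helper: fit a fresh clone per fold and score it on the held-out fold."""
    scores = []
    for train_idx, test_idx in KFold(n_splits=n_splits).split(X):
        model = clone(estimator)  # unfitted copy with the same parameters
        model.fit(X[train_idx], y[train_idx])
        scores.append(model.score(X[test_idx], y[test_idx]))  # R^2 for regressors
    return np.array(scores)

manual = manual_cv_scores(lasso, X, y)
builtin = cross_val_score(lasso, X, y, cv=3)
print(np.allclose(manual, builtin))  # True
```

The comparison prints True because both paths use the same folds and the same deterministic estimator.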

Reposted from blog.csdn.net/Du_Shuang/article/details/84327522