Detailed usage of the ShuffleSplit class in sklearn (machine learning)

✌ How to use ShuffleSplit

1. ✌ Principle

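ShuffleSplit randomly shuffles the sample set and divides it into a training set and a test set (the test set can be understood as a validation set, the same below).
It is similar to cross-validation, except that every split is drawn independently, so the test sets of different splits may overlap.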
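
As a minimal sketch of the idea (the 10-sample toy array and the parameter values are assumptions chosen only for illustration), each iteration returns a fresh random partition of the sample indices:

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit

X = np.arange(20).reshape(10, 2)  # toy data: 10 samples, 2 features (illustrative only)

ss = ShuffleSplit(n_splits=3, test_size=0.2, random_state=0)

# split() yields one (train_indices, test_indices) pair per iteration
splits = list(ss.split(X))
for i, (train_idx, test_idx) in enumerate(splits):
    # Within a single split, train and test never share a sample
    print(f"split {i}: train={train_idx} test={test_idx}")
```

Each split here holds 8 training and 2 test indices. The same sample may show up in the test set of more than one split, which is the main difference from KFold.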

2. ✌ Function form

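ShuffleSplit(n_splits=10, test_size=None, train_size=None, random_state=None)

Note: older scikit-learn releases showed test_size='default'. In current releases the default is None, which falls back to 0.1 when train_size is also None.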
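
To see what the defaults do in practice, here is a small sketch, assuming a recent scikit-learn release and a made-up array of 100 samples:

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit

X = np.zeros((100, 1))  # toy data: 100 samples (illustrative only)

default_cv = ShuffleSplit()  # every parameter left at its default

n_splits = default_cv.get_n_splits()
# Collect the distinct (n_train, n_test) sizes seen across all splits
sizes = {(len(train), len(test)) for train, test in default_cv.split(X)}

print(n_splits, sizes)
```

With the defaults, 10 splits are drawn and each one holds out 10% of the samples, so every split here is 90 training samples and 10 test samples.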

3. ✌ Important parameters

n_splits:

The number of times the data set is re-shuffled and split, similar to the number of folds in KFold; the default is 10

test_size:

The proportion of the total sample used as the test set. For example, with test_size=0.2, 20% of the samples go into the test set of each split

random_state:

Random number seed. Fixing it makes the splits identical on every run
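
The three parameters above can be checked with a short sketch. The 50-sample array and the helper `describe` are hypothetical, introduced only for this illustration:

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit

X = np.zeros((50, 1))  # toy data: 50 samples (illustrative only)

def describe(cv):
    """Hypothetical helper: (n_train, n_test, test_indices) for every split."""
    return [(len(tr), len(te), te.tolist()) for tr, te in cv.split(X)]

# n_splits sets how many random splits are drawn;
# a float test_size is a fraction of all samples (0.2 of 50 = 10)
frac = describe(ShuffleSplit(n_splits=4, test_size=0.2, random_state=0))

# An integer test_size is an absolute number of test samples
absolute = describe(ShuffleSplit(n_splits=4, test_size=5, random_state=0))

# Same seed on a new object: the splits come out identical
repeat = describe(ShuffleSplit(n_splits=4, test_size=0.2, random_state=0))

print(len(frac), frac[0][:2])
print(absolute[0][:2])
print(frac == repeat)
```

Here test_size=0.2 gives 40 training and 10 test samples per split, test_size=5 gives 45 and 5, and reusing random_state=0 reproduces the same test indices.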

4. ✌ Code example

A learning curve plots the accuracy on the training set and the cross-validation accuracy for different training set sizes. It shows how the model performs on new data, which helps determine whether the model's variance or bias is too high, and whether enlarging the training set can reduce overfitting.
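
Before drawing anything, the same information can be read from the raw numbers. The sketch below is only an illustration: the small n_splits, n_estimators and train_sizes values are assumptions chosen so that it runs quickly.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import ShuffleSplit, learning_curve

x, y = load_digits(return_X_y=True)

# Small settings so the sketch runs quickly (assumed values, not tuned)
cv = ShuffleSplit(n_splits=3, test_size=0.2, random_state=0)
model = RandomForestClassifier(n_estimators=20, random_state=0)

train_sizes, train_scores, val_scores = learning_curve(
    model, x, y, cv=cv, train_sizes=np.linspace(0.2, 1.0, 4))

# Rows are training set sizes, columns are the individual splits
train_mean = train_scores.mean(axis=1)
val_mean = val_scores.mean(axis=1)
gap = train_mean - val_mean  # train score minus validation score

for n, tr, va, g in zip(train_sizes, train_mean, val_mean, gap):
    print(f"n={int(n):4d}  train={tr:.3f}  validation={va:.3f}  gap={g:.3f}")
```

As a rough reading guide: low scores on both curves point to high bias, while a high training score with a clearly lower validation score points to high variance. If the validation score is still rising at the largest training size, more data may help.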

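✌ Import libraries
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits # handwritten digits data set
from sklearn.ensemble import RandomForestClassifier # random forest classifier
from sklearn.model_selection import learning_curve # learning curve function
from sklearn.model_selection import ShuffleSplit # random shuffle-and-split cross-validator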
✌ Load data
fig,ax=plt.subplots(1,1,figsize=(6,6)) # create the figure and a single subplot
data=load_digits()
x,y=data.data,data.target # feature matrix and labels
✌ Drawing
# Classifier: random forest on x, y; cv uses ShuffleSplit splitting; n_jobs=4 runs 4 jobs in parallel
train_sizes,train_scores,test_scores=learning_curve(
    RandomForestClassifier(n_estimators=50),x,y,
    cv=ShuffleSplit(n_splits=50,test_size=0.2,random_state=0),n_jobs=4)
ax.set_ylim((0.7,1.1)) # set the y-axis range of the subplot to (0.7, 1.1)
ax.set_xlabel("training examples") # x-axis label
ax.set_ylabel("score") # y-axis label
ax.grid() # draw the grid
# Training scores: x is the number of training samples, y is the mean training score over the splits
ax.plot(train_sizes,np.mean(train_scores,axis=1),'o-',color='r',label='train score')
# Test scores: mean test (validation) score over the splits
ax.plot(train_sizes,np.mean(test_scores,axis=1),'o-',color='g',label='test score')
ax.legend(loc='best') # add the legend

plt.show()

Origin blog.csdn.net/m0_47256162/article/details/113763125