使用决策树完成随机森林分类

使用sklearn 的decisiontreeclassfier 函数再数据集上完成随机森林分类。

import numpy as np
from sklearn.model_selection import GridSearchCV
from  sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import ShuffleSplit
from sklearn.datasets import make_moons
from sklearn.metrics import accuracy_score
X,y=make_moons(n_samples=10000,noise=0.4,random_state=42)
xtr,xte,ytr,yte=train_test_split(X,y,test_size=0.2,random_state=42)
#简历有噪声的数据集,然后将数据集拆分为训练集和测试,比例为4-1
xtr,xte,ytr,yte=train_test_split(X,y,test_size=0.2,random_state=42)
p={
    
    'max_leaf_nodes': list(range(2, 100)), 'min_samples_split': [2, 3, 4]}
grid_cv=GridSearchCV(DecisionTreeClassifier(random_state=42),p,n_jobs=-1,verbose=1,cv=3)
#建立网格搜索,寻找最优模型
grid_cv.fit(xtr,ytr)
yp=grid_cv.predict(xte)#使用测试集预测
accuracy_score(yte,yp)#获得分类精度分数

使用月球数据,然后对数据划分测试和训练集,然后使用网格模型查找最优模型,训练数据后查看分类精度。


nt=1000;nr=100;mini=[]
rs=ShuffleSplit(n_splits=nt,test_size=len(xtr)-nr,random_state=42)
for train_id,test_id in rs.split(xtr):
    m_xtr=xtr[train_id]
    m_ytr=ytr[test_id]
    mini.append((m_xtr,m_ytr))
#将数据集拆分成1000个训练集,每个训练集随机挑选100个实例
forest= [clone(grid_cv.best_estimator_) for _ in range(nt)]#获得1000个最优模型的克隆集合
acs=[]
for t,(xm,ym) in zip(forest,mini):
    t.fit(xm,ym)
    yp=t.predict(xte)
    acs.append(accuracy_score(yte,yp))
np.mean(acs)    #对循环训练集获得的分数取均值

y_pr = np.empty([nt, len(xte)], dtype=np.uint8)
for tree_index, tree in enumerate(forest):
    Y_pred[tree_index] = tree.predict(X_test)
from scipy.stats import mode
y_pred_majority_votes, n_votes = mode(Y_pred, axis=0)
#每个测试集实例,生成1000个决策树的预测,并且只保留最频繁的预测

首先使用ShuffleSplit生成1000个训练子集,每个子集包含100个实例,然后使用for循环将每个子集的实例数据添加到一起。
其次将上面训练得出的最优模型克隆1000份,使用for循环,将克隆的最优模型和1000个训练集赋值给t,(xm,ym),然后依次训练,每次训练完的模型都对测试集数据预测,将预测结果和真实测试比较精度,将分数添加入acs。最后对测试分数集合取均值查看一下平均精度。
对于每个测试集实例,生成1000个决策树的预测,并且只保留最频繁的预测,会获得一个比上面使用决策树获得的精度更高的模型,也就是随机森林。
在这里插入图片描述

在这里插入图片描述
下面的是使用1000个决策树模型对训练集的预测精度,上面是单纯使用决策树预测的精度。在这个数据集上大约提升了0.25个百分点左右的精度。

猜你喜欢

转载自blog.csdn.net/lisenby/article/details/108616329
今日推荐