【机器学习】CART回归树(预剪枝)—— python3 实现方案

 因为要写GBDT回归树的算法,按照cart分类树的格式写了回归树。

跟《机器学习实战》的回归树不同,这里采用了“限制最大深度” 的方式来预剪枝,实现起来更加方便,虽然效果可能没有后剪枝好,但是用于GBDT应该是够用的。

import numpy as np

class CARTRegression():
    '''
    cart回归树作弱学习器,平方误差函数作损失函数。Loss = 1/2*(y-h(x))**2
    '''
    def caclSE(self, dataSet):
        '''
        计算CART回归树的节点方差Squared Error
        :param dataSet: 数据集,包含目标列。  np.array,m*(n+1)
        :return: 当前节点(目标列)的方差
        '''
        if dataSet.shape[0] == 0:  # 如果输入一个空数据集,则返回0
            return 0
        return np.var(dataSet[:, -1]) * dataSet.shape[0]  # 方差=均方差*样本数量

    def splitDataSet(self, dataSet, feature, value):
        '''
        根据给定特征值,二分数据集。
        :param dataSet: 同上
        :param feature: 待划分特征。因为是处理回归问题,这里我们假定数据集的特征都是连续型
        :param value: 阀值
        :return: 特征值小于等于阀值或大于阀值的两个子数据集. k*(n+1), (m-k)*(n+1)
        '''
        arr1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]  # 利用np.nonzero返回目标样本的索引值
        arr2 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
        return arr1, arr2

    def chooseBestFeature(self, dataSet):
        '''
        通过比较所有节点的方差和,选出方差和最小的特征与对应的特征值
        :param dataSet: 同上
        :return: 最佳划分点和划分值
        '''
        n = dataSet.shape[1] - 1  # m是样本数量,n是特征数量
        minErr = np.inf  # 初始化最小方差为无穷大的正数
        bestFeature, bestValue = 0, 0  # 声明变量的类型
        for feature in range(n):
            values = set(dataSet[:, feature].tolist())  # 选取所有出现过的值作为阀值
            for value in values:
                arr1, arr2 = self.splitDataSet(dataSet, feature, value)
                err1 = self.caclSE(arr1)
                err2 = self.caclSE(arr2)
                newErr = err1 + err2
                # 选取方差和最小的特征和对应的阀值
                if newErr < minErr:
                    minErr = newErr
                    bestFeature = feature
                    bestValue = value
        return bestFeature, bestValue

    def calcLeaf(self, dataSet):
        '''
        计算当前节点的目标列均值(作为当前节点的预测值)
        :param dataSet: 同上
        :return: 目标列均值
        '''
        return np.mean(dataSet[:, -1])

    def createTree(self, dataSet, max_depth=4):
        '''
        创建CART回归树
        :param dataSet: 同上
        :param max_depth: 设定回归树的最大深度,防止无限生长(过拟合)
        :return: 字典形式的cart回归树模型
        '''
        if len(set(dataSet[:, -1].tolist())) == 1:  # 如果当前节点的值都相同,结束递归
            return self.calcLeaf(dataSet)
        if max_depth == 1:  # 如果层数超出设定层数,结束递归
            return self.calcLeaf(dataSet)
        # 创建回归树
        bestFeature, bestValue = self.chooseBestFeature(dataSet)
        mytree = {}
        mytree['FeatureIndex'] = bestFeature  # 存储分割特征值的索引
        mytree['FeatureValue'] = bestValue  # 存储阀值
        lSet, rSet = self.splitDataSet(dataSet, bestFeature, bestValue)
        mytree['left'] = self.createTree(lSet, max_depth - 1)  # 存储左子树的信息
        mytree['right'] = self.createTree(rSet, max_depth - 1)  # 存储右子树的信息

        return mytree

    def predict(self, cartTree, testData):
        '''
        根据训练好的cart回归树,预测待测数据的值
        :param cartTree: 训练好的cart回归树
        :param testData: 待测试数据, 1*n
        :return: 预测值
        '''
        if not isinstance(cartTree, dict):  # 不是字典,意味着到了叶子结点,此时返回叶子结点的值即可
            return cartTree
        featureIndex = cartTree['FeatureIndex']  # 获取回归树的第一层特征索引
        featureVal = testData[featureIndex]  # 根据特征索引找到待测数据对应的特征值, 作为下面是进入左子树还是右子树的依据
        if featureVal <= cartTree['FeatureValue']:
            return self.predict(cartTree['left'], testData)
        elif featureVal > cartTree['FeatureValue']:
            return self.predict(cartTree['right'], testData)


# 以下是测试数据
dataSet = np.array([[1, 5.56],
                    [2, 5.70],
                    [3, 5.91],
                    [4, 6.40],
                    [5, 6.80],
                    [6, 7.05],
                    [7, 8.90],
                    [8, 8.70],
                    [9, 9.00],
                    [10, 9.05]])

cart = CARTRegression()
mytree = cart.createTree(dataSet)
print(mytree)
print(cart.predict(mytree, [10]))

猜你喜欢

转载自blog.csdn.net/zhenghaitian/article/details/81136686
今日推荐