决策树一一CART算法(第三部分)

决策树一一CART算法(第三部分)


CART-回归树模型

如果输出变量是 连续 的,对应的就是 回归 问题,对于决策树而言,输出的信息一定就是叶子结点,所以需要将连续变量按照一定的要求划分。

回归树模型

假设将输入空间划分成 M个单元, , R 1 , R 2 , . . . . , R M R_1,R_2,....,R_M R1,R2,....,RM,并在每个单元 上有一个固定的输出值 ,回归树模型可以表示为:
f ( x ) = ∑ m = 1 M c m I ( x ∈ R m ) f(x)=\sum_{m=1}^{M} c_{m} I\left(x \in R_{m}\right) f(x)=m=1McmI(xRm)
f ( x ) f(x) f(x) 就是CART回归树模型, c m c_m cm 代表输出的类, I ( x ∈ R m ) I\left(x \in R_{m}\right) I(xRm) 就是指示性函数。

假设输入和输出变量如下表:

输入 R 1 R_1 R1 R 2 R_2 R2 R m R_m Rm
输出 C 1 C_1 C1 C 2 C_2 C2 C M C_M CM

I ( x ∈ R m ) I\left(x \in R_{m}\right) I(xRm)就是指当 ( x ∈ R m ) \left(x \in R_{m}\right) (xRm)时1, x ∉ R m x\notin R_m x/Rm时取0

也就是说当某个输出单元也就是类 C m C_m Cm而言,当输入单元 R m R_m Rm和它一致时就存在,不一致则不存在,把所有输入单元对应的类求和之后,便是最终的回归树模型。

平方误差最小化找切分点

选择第 x j x^{j} xj个变量和取值s ,分别作为切分变量和切分点,并定义两个区域:
R 1 ( j , s ) = x ∣ x ( j ) ≤ s R_{1}(j, s)=x \mid x^{(j)} \leq s R1(j,s)=xx(j)s

R 1 ( j , s ) = x ∣ x ( j ) > s R_{1}(j, s)=x \mid x^{(j)} > s R1(j,s)=xx(j)>s
用「平方误差最小化」来寻找最优切分变量 j和最优切分点 s:
min ⁡ j , s [ min ⁡ c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + min ⁡ c e ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] \min _{j, s}\left[\min _{c_{1}} \sum_{x_{i} \in R_{1}(j, s)}\left(y_{i}-c_{1}\right)^{2}+\min _{c_{e}} \sum_{x_{i} \in R_{2}(j, s)}\left(y_{i}-c_{2}\right)^{2}\right] j,sminc1minxiR1(j,s)(yic1)2+ceminxiR2(j,s)(yic2)2
这个公式将输出变量按照输入变量分为了两类,求出每次分类后的各个分类的平方误差最小值之和,相当于整体的最小平方误差。平方误差最小,分类和实际最吻合。

其中:
c ^ 1 = a v e ( y i ∣ x i ∈ R 1 ( j , s ) ) \hat c_1=ave(y_i|x_i\in R_1(j,s)) c^1=ave(yixiR1(j,s))

c ^ 2 = a v e ( y i ∣ x i ∈ R 2 ( j , s ) ) \hat c_2=ave(y_i|x_i\in R_2(j,s)) c^2=ave(yixiR2(j,s))

要让平方误差最小,则每次分类后的 c ^ 1 和 c ^ 2 \hat c_1和\hat c_2 c^1c^2应设置为对应的每个区域内的输出变量的平均值。

回归树步骤

输入:训练数据集D,停止条件
输出:CART决策树T

  1. 从根节点出发,进行操作,构建二叉树:

  2. 结点处的训练数据集为D,计算变量的最优切分点,并选择最优变量。

    min ⁡ j , s [ min ⁡ c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + min ⁡ c 2 ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] \min _{j, s}\left[\min _{c_{1}} \sum_{x_{i} \in R_{1}(j, s)}\left(y_{i}-c_{1}\right)^{2}+\min _{c_{2}} \sum_{x_{i} \in R_{2}(j, s)}\left(y_{i}-c_{2}\right)^{2}\right] j,sminc1minxiR1(j,s)(yic1)2+c2minxiR2(j,s)(yic2)2
    ——在第j变量下,对其可能取的每个值s,根据样本点分割成 R 1 R_1 R1 R 2 R_2 R2两部分,计算切分点为s时的平方误差。
    ——选择平方误差最小的那个值作为该变量下的最优切分点。
    ——计算每个变量下的最优切分点,并比较在最优切分下的每个变量的平方误差,选择平 方误羞最小的那个变量,即最优变量。

  3. 根据最优特征与最优切分点(j,s),从现结点生成两个子结点,将训练数据集依变量配到两个子结点中去,得到相应的输出值。
    R 1 ( j , s ) = x ∣ x ( j ) ≤ s , R 1 ( j , s ) = x ∣ x ( j ) > s R_{1}(j, s)=x \mid x^{(j)} \leq s ,R_{1}(j, s)=x \mid x^{(j)} > s R1(j,s)=xx(j)sR1(j,s)=xx(j)>s

    c ^ m = 1 N m ∑ x i ∈ R m ( j , s ) y i , x ∈ R m , m = 1 , 2 \hat{c}_{m}=\frac{1}{N_{m}} \sum_{x_{i} \in R_{m}(j, s)} y_{i}, \quad x \in R_{m}, \quad m=1,2 c^m=Nm1xiRm(j,s)yi,xRm,m=1,2

  4. 继续对两个子区域调用上述步骤,直至满足停止条件,即生成CART决策树。

f ( x ) = ∑ m = 1 M c m I ( x ∈ R m ) f(x)=\sum_{m=1}^{M} c_{m} I\left(x \in R_{m}\right) f(x)=m=1McmI(xRm)


回归树例题

拿西瓜为例,对于甜度这个特征,可以分成好吃和不好吃的西瓜两类,可以把连续变量的样本拿来进行划分:

「输入」:用 [0,0.5]来表示由不甜到甜的程度

「输出」:用 [1,10]来表示由不好吃到好吃的程度

甜度 0.05 0.15 0.25 0.35 0.45
好吃的程度 5.5 7.6 9.5 9.7 8.2

由于CART算法是二叉树,所以我们[每次划分只能划分成两类

比如:甜度 ≤ 0.1 \leq0.1 0.1和甜度>0.1 这样两类,然后可以再继续在甜度 >0.1这个范围内,以此类推,选择最优切分点继续划分。

第一次划分:以甜度s=0.1 进行划分

R 1 R_{1} R1

甜度 0.05
好吃程度 5.5

R 2 R_{2} R2

甜度 0.15 0.25 0.35 0.45
好吃的程度 7.6 9.5 9.7 8.2

R 1 R_{1} R1类平均值 c ^ 1 = 5.5 \hat c_1=5.5 c^1=5.5

R 2 R_{2} R2类平均值 c ^ 2 = 7.6 + 9.5 + 9.7 + 8.2 4 = 8.75 \hat c_2=\frac{7.6+9.5+9.7+8.2}{4}=8.75 c^2=47.6+9.5+9.7+8.2=8.75

代入平方误差公式
min ⁡ j , s [ min ⁡ c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + min ⁡ c e ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] = 0 + ( 7.6 − 8.75 ) 2 + ( 9.5 − 8.75 ) 2 + ( 9.7 − 8.75 ) 2 + ( 8.2 − 8.75 ) 2 = 3.09 \min _{j, s}\left[\min _{c_{1}} \sum_{x_{i} \in R_{1}(j, s)}\left(y_{i}-c_{1}\right)^{2}+\min _{c_{e}} \sum_{x_{i} \in R_{2}(j, s)}\left(y_{i}-c_{2}\right)^{2}\right]\\ =0+(7.6-8.75)^2+(9.5-8.75)^2+(9.7-8.75)^2+(8.2-8.75)^2=3.09 j,sminc1minxiR1(j,s)(yic1)2+ceminxiR2(j,s)(yic2)2=0+(7.68.75)2+(9.58.75)2+(9.78.75)2+(8.28.75)2=3.09
同理

第二次划分:以甜度s=0.2 进行划分,得到平方误差为3.53

第三次划分:以甜度s=0.3进行划分,得到平方误差为9.13

第四次划分:以甜度s=0.4 进行划分,得到平方误差为11.52

比较四次的结果,选择平方误差最小值的的点作为切割点,即s=0.1为切割点

输出CART回归树模型:
f ( n ) = { 5.5 , s ≤ 0.1 8.75 , s > 0.1 f(n)= \begin{cases} 5.5, & s \leq 0.1 \\ 8.75, & s>0.1 \end{cases} f(n)={ 5.5,8.75,s0.1s>0.1
剩下的部分还可以对s>0.1区域进行回归划分。

总结:

通过对【连续】 变量进行划分,转换为【离散】的变量来进行计算,那么就和之前的分类树模型相通,这也就是为什么常见的都是CART分类树模型啦。

猜你喜欢

转载自blog.csdn.net/qq_44795788/article/details/124693678