用Python实现CART

1. 离散值特征的分类问题

from collections import Counter,defaultdict

import numpy as np

class Node:
    """A single node in a CART binary tree (discrete-feature variant).

    Leaf nodes carry a class label in ``res``; internal nodes carry the
    split feature index in ``feat`` and the split value in ``val``.
    """

    def __init__(self, feat=-1, val=None, res=None, left=None, right=None):
        self.feat = feat    # index of the feature this node splits on
        self.val = val      # feature value used for the equality split
        self.res = res      # predicted label at a leaf; None for internal nodes
        self.left = left    # subtree for samples with sample[feat] == val
        self.right = right  # subtree for all remaining samples

    def __repr__(self):
        # Leaves print their label; internal nodes print both children
        # wrapped in brackets, e.g. "[01]".
        if self.res is not None:
            return str(self.res)
        return '[' + repr(self.left) + repr(self.right) + ']'


class CART:
    """CART classifier for discrete-valued features.

    Grows a binary tree by repeatedly choosing the (feature, value)
    split that minimizes the weighted Gini index; a sample goes left
    when sample[feat] == val and right otherwise.
    """

    def __init__(self, epsilon=1e-3, min_sample=1):
        # epsilon: stop splitting when the Gini reduction falls below this.
        # min_sample: nodes with fewer samples than this become leaves.
        self.epsilon = epsilon
        self.min_sample = min_sample
        self.tree = None

    def getGini(self, y_data):
        """Return the Gini impurity of the 1-D label array ``y_data``."""
        counter = Counter(y_data)
        length = y_data.shape[0]
        return 1 - sum((v / length) ** 2 for v in counter.values())

    def getFeatGini(self, *setn):
        """Return the size-weighted Gini impurity of a label partition."""
        num = sum(seti.shape[0] for seti in setn)
        return sum((seti.shape[0] / num) * self.getGini(seti) for seti in setn)

    def bestSplit(self, splits_set, x_data, y_data):
        """Pick the candidate split with the lowest weighted Gini index.

        Returns ``(best_split, best_set, min_gini)`` where ``best_split``
        is a (feature, value) tuple and ``best_set`` holds the two index
        lists of the resulting subsets. ``best_split`` is None when no
        split reduces the Gini index by at least ``epsilon``.
        """
        pre_gini = self.getGini(y_data)
        subdata_inds = defaultdict(list)  # split -> indices of matching samples
        for split in splits_set:
            for ind, sample in enumerate(x_data):
                if sample[split[0]] == split[1]:
                    subdata_inds[split].append(ind)
        min_gini = 1  # Gini impurity never exceeds 1, so this is an upper bound
        best_split = None
        best_set = None
        length = y_data.shape[0]
        for split, data_ind in subdata_inds.items():
            set1 = y_data[data_ind]  # labels of samples matching the split value
            set2_inds = list(
                set(range(length)) - set(data_ind)
            )
            set2 = y_data[set2_inds]
            if set1.shape[0] < 1 or set2.shape[0] < 1:
                # A one-sided "split" partitions nothing; skip it.
                continue
            now_gini = self.getFeatGini(set1, set2)
            if now_gini < min_gini:
                min_gini = now_gini
                best_split = split
                best_set = (data_ind, set2_inds)
        if abs(pre_gini - min_gini) < self.epsilon:
            # Gini reduction too small: signal the caller to make a leaf.
            best_split = None
        return best_split, best_set, min_gini

    def buildTree(self, splits_set, x_data, y_data):
        """Recursively grow the tree; return the root Node of the subtree."""
        if y_data.shape[0] < self.min_sample:
            return Node(res=Counter(y_data).most_common(1)[0][0])
        best_split, best_set, min_gini = self.bestSplit(splits_set, x_data, y_data)
        if best_split is None:  # no useful split -> majority-label leaf
            return Node(res=Counter(y_data).most_common(1)[0][0])
        # BUG FIX: the original code removed best_split from the shared
        # splits_set list in place, so splits consumed while growing the
        # left subtree were silently unavailable to the right subtree (and
        # to sibling branches higher up the recursion). Each child now gets
        # its own independent copy of the remaining candidate splits.
        remaining = [s for s in splits_set if s != best_split]
        nodes = [0] * 2
        for i in range(2):
            nodes[i] = self.buildTree(list(remaining), x_data[best_set[i]], y_data[best_set[i]])
        return Node(feat=best_split[0], val=best_split[1], left=nodes[0], right=nodes[1])

    def fit(self, x_data, y_data):
        """Enumerate candidate (feature, value) splits and grow the tree."""
        splits_set = []
        for feat in range(x_data.shape[1]):
            unique_vals = np.unique(x_data[:, feat])
            if unique_vals.shape[0] < 2:
                continue  # a constant feature cannot split the data
            elif unique_vals.shape[0] == 2:
                # Two values give mirror-image splits; one candidate suffices.
                splits_set.append((feat, unique_vals[0]))
            else:
                for val in unique_vals:
                    splits_set.append((feat, val))
        self.tree = self.buildTree(splits_set, x_data, y_data)

    def travel(self, feat, tree):
        """Walk one feature vector down the tree and return the leaf label."""
        if tree.res is not None:
            return tree.res
        if feat[tree.feat] == tree.val:
            branch = tree.left
        else:
            branch = tree.right
        return self.travel(feat, branch)

    def predict(self, x):
        """Return an array with the predicted label for every row of ``x``."""
        return np.array([
            self.travel(feat, self.tree) for feat in x
        ])
# Demo: train the CART classifier defined above on the iris data set and
# compare its predictions against the held-out labels.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# NOTE(fix): the original snippet re-imported the class via
# ``from main import CART``, which shadows the class defined above and
# breaks outright unless this file happens to be named main.py. The
# locally defined CART is used directly instead.
iris = load_iris()  # load once instead of calling load_iris() twice
x = iris.data
y = iris.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=10)
model = CART()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
print(y_pred)
print(y_test)

2. 连续值特征的分类问题

from collections import Counter,defaultdict

import numpy as np

class Node:
    """A single node in a CART binary tree (continuous-feature variant).

    Leaf nodes carry a class label in ``res``; internal nodes carry the
    split feature index in ``feat`` and a half-open interval ``(s, e)``
    in ``val``.
    """

    def __init__(self, feat=-1, val=None, res=None, left=None, right=None):
        self.feat = feat    # index of the feature this node splits on
        self.val = val      # (s, e) interval: left branch iff s <= x[feat] < e
        self.res = res      # predicted label at a leaf; None for internal nodes
        self.left = left    # subtree for samples inside the interval
        self.right = right  # subtree for all remaining samples

    def __repr__(self):
        # Leaves print their label; internal nodes print both children
        # wrapped in brackets, e.g. "[01]".
        if self.res is not None:
            return str(self.res)
        return '[' + repr(self.left) + repr(self.right) + ']'


class CART:
    """CART classifier for continuous-valued features.

    Candidate splits are half-open intervals (feat, s, e) between
    consecutive unique training values (padded with +/-inf); the tree
    chooses the interval that minimizes the weighted Gini index, and a
    sample goes left when s <= sample[feat] < e.
    """

    def __init__(self, epsilon=1e-3, min_sample=1):
        # epsilon: stop splitting when the Gini reduction falls below this.
        # min_sample: nodes with fewer samples than this become leaves.
        self.epsilon = epsilon
        self.min_sample = min_sample
        self.tree = None

    def getGini(self, y_data):
        """Return the Gini impurity of the 1-D label array ``y_data``."""
        counter = Counter(y_data)
        length = y_data.shape[0]
        return 1 - sum((v / length) ** 2 for v in counter.values())

    def getFeatGini(self, *setn):
        """Return the size-weighted Gini impurity of a label partition."""
        num = sum(seti.shape[0] for seti in setn)
        return sum((seti.shape[0] / num) * self.getGini(seti) for seti in setn)

    def bestSplit(self, splits_set, x_data, y_data):
        """Pick the candidate interval split with the lowest weighted Gini.

        Returns ``(best_split, best_set, min_gini)`` where ``best_split``
        is a (feature, s, e) tuple and ``best_set`` holds the two index
        lists of the resulting subsets. ``best_split`` is None when no
        split reduces the Gini index by at least ``epsilon``.
        """
        pre_gini = self.getGini(y_data)
        subdata_inds = defaultdict(list)  # split -> indices of samples in [s, e)
        for split in splits_set:
            feat, s, e = split
            for ind, sample in enumerate(x_data):
                if s <= sample[feat] < e:
                    subdata_inds[split].append(ind)
        min_gini = 1  # Gini impurity never exceeds 1, so this is an upper bound
        best_split = None
        best_set = None
        length = y_data.shape[0]
        for split, data_ind in subdata_inds.items():
            set1 = y_data[data_ind]  # labels of samples inside the interval
            set2_inds = list(
                set(range(length)) - set(data_ind)
            )
            set2 = y_data[set2_inds]
            if set1.shape[0] < 1 or set2.shape[0] < 1:
                # A one-sided "split" partitions nothing; skip it.
                continue
            now_gini = self.getFeatGini(set1, set2)
            if now_gini < min_gini:
                min_gini = now_gini
                best_split = split
                best_set = (data_ind, set2_inds)
        if abs(pre_gini - min_gini) < self.epsilon:
            # Gini reduction too small: signal the caller to make a leaf.
            best_split = None
        return best_split, best_set, min_gini

    def buildTree(self, splits_set, x_data, y_data):
        """Recursively grow the tree; return the root Node of the subtree."""
        if y_data.shape[0] < self.min_sample:
            return Node(res=Counter(y_data).most_common(1)[0][0])
        best_split, best_set, min_gini = self.bestSplit(splits_set, x_data, y_data)
        if best_split is None:  # no useful split -> majority-label leaf
            return Node(res=Counter(y_data).most_common(1)[0][0])
        feat, s, e = best_split
        # BUG FIX: the original code removed best_split from the shared
        # splits_set list in place, so splits consumed while growing the
        # left subtree were silently unavailable to the right subtree (and
        # to sibling branches higher up the recursion). Each child now gets
        # its own independent copy of the remaining candidate splits.
        remaining = [sp for sp in splits_set if sp != best_split]
        nodes = [0] * 2
        for i in range(2):
            nodes[i] = self.buildTree(list(remaining), x_data[best_set[i]], y_data[best_set[i]])
        return Node(feat=feat, val=(s, e), left=nodes[0], right=nodes[1])

    def fit(self, x_data, y_data):
        """Build interval candidates between unique values and grow the tree."""
        splits_set = []
        for feat in range(x_data.shape[1]):
            unique_vals = np.unique(x_data[:, feat])
            # Pad with +/-inf so every real value falls into some interval.
            bins = np.concatenate([[-np.inf], unique_vals, [np.inf]])
            for i in range(len(bins) - 1):
                s = bins[i]
                e = bins[i + 1]
                splits_set.append((feat, s, e))
        self.tree = self.buildTree(splits_set, x_data, y_data)

    def travel(self, feat, tree):
        """Walk one feature vector down the tree and return the leaf label."""
        if tree.res is not None:
            return tree.res
        s, e = tree.val
        if s <= feat[tree.feat] < e:
            branch = tree.left
        else:
            branch = tree.right
        return self.travel(feat, branch)

    def predict(self, x):
        """Return an array with the predicted label for every row of ``x``."""
        return np.array([
            self.travel(feat, self.tree) for feat in x
        ])
发布了281 篇原创文章 · 获赞 35 · 访问量 6万+

猜你喜欢

转载自blog.csdn.net/TQCAI666/article/details/103099759