A CART decision tree for continuous numeric attributes

Dataset to split: Iris.

The data looks like this:

       a    b    c    d           class
0    5.1  3.5  1.4  0.2     Iris-setosa
1    4.9  3.0  1.4  0.2     Iris-setosa
2    4.7  3.2  1.3  0.2     Iris-setosa
3    4.6  3.1  1.5  0.2     Iris-setosa
4    5.0  3.6  1.4  0.2     Iris-setosa

Four numeric attributes in total, plus a class attribute.

Split selection criterion: the Gini index.
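As a quick standalone illustration of the criterion (`gini` here is a hypothetical helper, not a function from the script below): the Gini impurity of a label set is 1 minus the sum of squared class proportions.

```python
from collections import Counter

def gini(labels):
    # Gini impurity: 1 - sum over classes of (class proportion)^2
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

print(gini(['Iris-setosa'] * 4))                           # pure set -> 0.0
print(gini(['Iris-setosa'] * 2 + ['Iris-virginica'] * 2))  # 50/50 split -> 0.5
```

A pure node scores 0; the split chosen is the one whose size-weighted sum of these impurities is smallest.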

Handling continuous values: for each of the four dimensions of Iris, take the mean of every pair of adjacent (sorted) values, producing an (n-1)×4 set of candidate split points; then, for each dimension, compute the Gini index of every candidate split and take the one with the smallest value as the best split.
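The candidate-split step for a single attribute can be sketched as follows (a standalone illustration; `candidate_splits` is not a function in the script below):

```python
import numpy as np

def candidate_splits(values):
    # midpoints of adjacent sorted values -> n-1 candidate split points
    s = np.sort(np.asarray(values, dtype=float))
    return (s[:-1] + s[1:]) / 2

print(candidate_splits([5.1, 4.9, 4.7, 4.6, 5.0]))  # [4.65 4.8  4.95 5.05]
```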

Worth mentioning: decision trees built over continuous values are especially prone to overfitting, because the stopping condition is strict. After a binary split on some attribute, say a<=7.2, attribute a cannot be excluded from later splits the way a used attribute can in a discrete-valued tree, because a can still be split again, e.g. a<=3.5. So a proper stopping condition really needs to be completed by a pruning strategy. But I was so pleased once the splits came out that I couldn't wait to post them; pruning can wait until next time.

The code:

# -*- coding: utf-8 -*-
"""
Created on Wed Sep 20 11:16:37 2017

@author: wjw
"""
import numpy as np
import pandas as pd

def readText(filePath):
    # parse the comma-separated Iris file into a DataFrame
    with open(filePath, 'r') as f:
        lines = f.readlines()
    data = []

    for line in lines:
        dataList = line.strip().split(',')
        data.append([float(dataList[0]), float(dataList[1]), float(dataList[2]),
                     float(dataList[3]), dataList[4]])

    data = pd.DataFrame(data, columns=["a", "b", "c", "d", "class"])
    return data
"""
       a    b    c    d           class
0    5.1  3.5  1.4  0.2     Iris-setosa
1    4.9  3.0  1.4  0.2     Iris-setosa
2    4.7  3.2  1.3  0.2     Iris-setosa
3    4.6  3.1  1.5  0.2     Iris-setosa
4    5.0  3.6  1.4  0.2     Iris-setosa
"""

def binSplitData(data, feature, value):  # split the data in two on feature <= value

    data0 = data[data[feature] <= value]
    data1 = data[data[feature] > value]
    binData = [data0, data1]  # a list of two DataFrames

    return binData

def chooseBestFeatureToSplit(data):
    # best (split value, attribute) pair by minimum Gini index
    avg_set = process(data)
    gini = calGiniIndex(data, avg_set)
    min_avg, minColumn = getMINGini(gini, avg_set)

    return min_avg, minColumn

def tree(data):

    countList = data.groupby('class').count().iloc[:, 0]  # per-class counts of this node's data
    if countList.iloc[0] == data.shape[0]:  # all samples belong to one class: this node is a leaf
        print("class: %s" % (data.iloc[0, -1]))
        return  # data.iloc[0,-1]  # the class label

    min_avg, minColumn = chooseBestFeatureToSplit(data)

    print('now splitting on attribute %s' % (minColumn))

    binData = binSplitData(data, minColumn, min_avg)
    for i in range(2):

        if i == 0:
            print("if attribute %s<=%s" % (minColumn, min_avg))  # condition for the next split
        elif i == 1:
            print("if attribute %s>%s" % (minColumn, min_avg))

        tree(binData[i])
    return

def getMINGini(gini, avg_set):
    minV = 10.
    minColumn = 0

    for column_index in range(gini.shape[1]):
        gini_column = gini.iloc[:, column_index]
        newmin = min(gini_column)
        if newmin < minV:
            minV = newmin
            minColumn = column_index

    # row position of the minimum -> the corresponding candidate split value
    min_row = gini.iloc[:, minColumn].idxmin()
    min_avg = avg_set[minColumn][min_row]

    return min_avg, gini.columns[minColumn]

def calGiniIndex(data, avg_set):  # Gini index of every candidate split in avg_set, per attribute

    giniSet = []

    for index in range(avg_set.shape[0]):
        d = data.iloc[:, index]  # the index-th attribute column
        subavg_set = avg_set[index]
        sub_giniSet = []
        for avg in subavg_set:

            ndata = data[d <= avg]  # samples at or below the split point
            pdata = data[d > avg]

            gini = 0
            for dd in [ndata, pdata]:
                if dd.shape[0] == 0:
                    continue
                count_result = dd.groupby('class').size()  # per-class counts in this half
                sub_gini = 1 - sum((cr / dd.shape[0]) ** 2 for cr in count_result)
                gini += (dd.shape[0] / data.shape[0]) * sub_gini  # weight by subset size
            sub_giniSet.append(gini)
        giniSet.append(sub_giniSet)

    giniSet = pd.DataFrame(np.array(giniSet).T, columns=list('abcd'))

    return giniSet
        
def process(data):  # candidate split set: n-1 midpoints per continuous attribute
    avg_set = []
    for i in range(data.shape[1] - 1):
        subavg_set = []
        d = data.iloc[:, i]
        sorted_data = np.sort(d)  # ascending sort; returns an array
        for j in range(sorted_data.size - 1):
            subavg_set.append((sorted_data[j] + sorted_data[j + 1]) / 2)
        avg_set.append(subavg_set)
    return np.array(avg_set)

if __name__ == "__main__":
    filePath = r"E:\data\iris.txt"
    data= readText(filePath)
    tree(data)
    
Output:

now splitting on attribute c
if attribute c<=1.9
class: Iris-setosa
if attribute c>1.9
now splitting on attribute d
if attribute d<=1.7
now splitting on attribute c
if attribute c<=4.9
now splitting on attribute d
if attribute d<=1.6
class: Iris-versicolor
if attribute d>1.6
class: Iris-virginica
if attribute c>4.9
now splitting on attribute d
if attribute d<=1.5
class: Iris-virginica
if attribute d>1.5
now splitting on attribute a
if attribute a<=6.95
class: Iris-versicolor
if attribute a>6.95
class: Iris-virginica
if attribute d>1.7
now splitting on attribute c
if attribute c<=4.8
now splitting on attribute a
if attribute a<=5.95
class: Iris-versicolor
if attribute a>5.95
class: Iris-virginica
if attribute c>4.8
class: Iris-virginica

Not bad, right? I think it's ok~

Pruning will come in the next post~






Reposted from blog.csdn.net/ge_nious/article/details/78063565