Zhou Zhihua's "Machine Learning", Chapter 4 Decision Trees: A Programming Attempt


1. Import the required packages

'''
import math  # entropy calculations
import re    # parsing numeric split points out of branch labels

# The two imports below are only needed for visualizing the tree with Graphviz;
# they are not used in the code shown in this post. Note that
# sklearn.externals.six has been removed from recent scikit-learn releases.
from sklearn.externals.six import StringIO
from pydotplus import graphviz
'''


2. Define the Node class

'''
class Node(object):
    def __init__(self, attr_init=None, label_init=None, attr_down_init=None):
        self.attr = attr_init    # attribute used to split at this node (None for a leaf)
        self.label = label_init  # majority class label of the samples reaching this node
        # Avoid a mutable default argument: a shared dict would be reused across nodes.
        self.attr_down = attr_down_init if attr_down_init is not None else {}
'''
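
As a quick illustration of how these fields are used once a tree is built (the attribute and values below are simply taken from the watermelon data; the child nodes are placeholders):

'''
# An internal node stores the splitting attribute and a dict mapping each attribute
# value (or split interval) to a child node; a leaf only carries its label.
leaf_yes = Node(label_init='是')
leaf_no = Node(label_init='否')
demo_root = Node(attr_init='纹理', label_init='是',
                 attr_down_init={'清晰': leaf_yes, '模糊': leaf_no})
'''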


3. Label-count dictionary

'''
def tagDict(tag_attr):
    # Count how many samples carry each class label (e.g. 是 / 否).
    tag_Dict = {}
    for tag in tag_attr:
        if tag in tag_Dict:
            tag_Dict[tag] += 1
        else:
            tag_Dict[tag] = 1
    return tag_Dict
'''
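
A quick sanity check on the class column of the watermelon 3.0 data, which contains 8 good melons and 9 bad ones:

'''
tagDict(['是'] * 8 + ['否'] * 9)   # -> {'是': 8, '否': 9}
'''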


4. Attribute-value count dictionary

'''
def predCount(pred_attr):
    # Count how many samples take each value of a (discrete) attribute.
    pred_Count = {}
    for pred in pred_attr:
        if pred in pred_Count:
            pred_Count[pred] += 1
        else:
            pred_Count[pred] = 1
    return pred_Count
'''
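
The same kind of check for predCount, here on the values of the 色泽 attribute from the data in section 12:

'''
predCount(['青绿'] * 6 + ['乌黑'] * 6 + ['浅白'] * 5)   # -> {'青绿': 6, '乌黑': 6, '浅白': 5}
'''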


5. Information entropy

'''
def Ent(tag_attr):
    # Information entropy: Ent(D) = -sum_k p_k * log2(p_k),
    # where p_k is the proportion of samples belonging to class k.
    ent_tag = 0
    n = len(tag_attr)
    tag_Dict = tagDict(tag_attr)
    for tag in tag_Dict:  # iterate over the distinct labels, not over every sample
        p = tag_Dict[tag] / n
        ent_tag -= p * math.log2(p)
    return ent_tag
'''
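
Applied to the full class column (8 good, 9 bad), this should reproduce the root entropy of about 0.998 worked out in the book:

'''
Ent(['是'] * 8 + ['否'] * 9)   # -> about 0.998
'''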


6. Information gain

'''
def Gain(df, predict, response):
    # Information gain of attribute `predict` with respect to the class column `response`:
    # Gain(D, a) = Ent(D) - sum_v |D_v| / |D| * Ent(D_v).
    gain = Ent(df[response])
    div_point = 0   # best split point; stays 0 for discrete attributes
    div_ent = {}

    n = len(df[predict])
    if df[predict].dtype in ('float64', 'int64'):
        # Continuous attribute: try the midpoints of adjacent sorted values as
        # candidate split points and keep the one with minimum weighted entropy.
        df = df.sort_values(by=[predict], ascending=True)
        df = df.reset_index(drop=True)

        pred_attr = df[predict]
        tag_attr = df[response]

        for i in range(1, n):
            div_point = (pred_attr[i - 1] + pred_attr[i]) / 2
            div_ent[div_point] = i * Ent(tag_attr[:i]) / n + (n - i) * Ent(tag_attr[i:]) / n

        div_point, numeric_ent = min(div_ent.items(), key=lambda x: x[1])
        gain -= numeric_ent

    else:
        # Discrete attribute: weight each value's entropy by its relative frequency.
        pred_attr = df[predict]
        tag_attr = df[response]
        pred_Count = predCount(pred_attr)

        for pred in pred_Count:
            attr_attr = tag_attr[pred_attr == pred]
            gain -= pred_Count[pred] * Ent(attr_attr) / n

    return gain, div_point
'''
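
A small illustration, assuming df has already been loaded as in section 10; the expected values are the ones worked out in the book for watermelon 3.0:

'''
Gain(df, '纹理', '好瓜')   # -> (about 0.381, 0): discrete attribute, no split point
Gain(df, '密度', '好瓜')   # -> (about 0.262, about 0.381): continuous attribute and its best split
'''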


7. Optimal attribute

'''
def getOptPred(df, response):
    # Pick the attribute with the largest information gain among all non-class columns.
    gain = 0
    opt_pred = None
    div_point = 0

    for pred in df.drop([response], axis=1):
        gain_tmp, div_point_tmp = Gain(df, pred, response)
        if gain_tmp > gain:
            gain = gain_tmp
            opt_pred = pred
            div_point = div_point_tmp

    return opt_pred, div_point
'''
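
On the full dataset this should reproduce the book's choice of 纹理 as the first splitting attribute (again assuming df from section 10):

'''
getOptPred(df, '好瓜')   # -> ('纹理', 0)
'''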


8. Main function: building the tree

'''
def Tree(df, response):
    # Recursively build the decision tree and return its root Node.
    new_node = Node(None, None, {})
    tag_attr = df[response]

    tag_Dict = tagDict(tag_attr)
    if tag_Dict:  # make sure the label count isn't empty
        new_node.label = max(tag_Dict, key=tag_Dict.get)

        # Stop if all samples share one class or there are no samples left.
        if len(tag_Dict) == 1 or len(tag_attr) == 0:
            return new_node

        new_node.attr, div_point = getOptPred(df, response)
        if new_node.attr is None:  # no attribute left (or none improves the gain): leaf
            return new_node

        if div_point == 0:
            # Discrete attribute: one branch per attribute value.
            pred_Count = predCount(df[new_node.attr])
            for attr in pred_Count:
                df_attr = df[df[new_node.attr].isin([attr])]
                df_attr = df_attr.drop([new_node.attr], axis=1)
                new_node.attr_down[attr] = Tree(df_attr, response)

        else:
            # Continuous attribute: binary split at div_point.
            # The ">" label must match the key format used by Predict below.
            point_left = "<=%.3f" % div_point
            point_right = ">%.3f" % div_point
            df_attr_left = df[df[new_node.attr] <= div_point]
            df_attr_right = df[df[new_node.attr] > div_point]

            new_node.attr_down[point_left] = Tree(df_attr_left, response)
            new_node.attr_down[point_right] = Tree(df_attr_right, response)

    return new_node
'''
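
Putting the pieces together on the full dataset (a minimal sketch; df is loaded as in section 10 and its last column 好瓜 is the class):

'''
root = Tree(df, df.columns[-1])
root.attr              # -> '纹理' on the full watermelon 3.0 data
list(root.attr_down)   # -> one branch per texture value
'''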


9. Prediction function

'''
def Predict(root, df_sample):
    # Walk down the tree until a leaf (a node with no splitting attribute) is reached.
    while root.attr is not None:
        if df_sample[root.attr].dtype in ('float64', 'int64'):
            # Continuous attribute: recover the split value from a branch key
            # such as "<=0.381" with a regular expression.
            for key in list(root.attr_down):
                num = re.findall(r"\d+\.?\d*", key)
                div_value = float(num[0])
                break
            if df_sample[root.attr].values[0] <= div_value:
                key = "<=%.3f" % div_value
            else:
                key = ">%.3f" % div_value
            root = root.attr_down[key]

        else:
            # Discrete attribute: follow the branch labelled with the sample's value.
            key = df_sample[root.attr].values[0]
            if key in root.attr_down:
                root = root.attr_down[key]
            else:
                break

    return root.label
'''
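
Classifying a single record, for example the first watermelon in the table in section 12 (using the root built above; since the tree is grown until the leaves are pure, training samples should come back correctly classified):

'''
sample = df[df.index == 0]   # record 编号 1
Predict(root, sample)        # -> '是'
'''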


10. Applying the code to the data

'''
import pandas as pd

data_file_encode = "gb18030"  # the codec of watermelon_3.csv
with open(".../data/watermelon_3.csv", mode='r', encoding=data_file_encode) as data_file:
    df = pd.read_csv(data_file)

response = df.columns[-1]  # the class column 好瓜
accuracy_scores = []

n = len(df.index)
k = 10
for i in range(k):
    m = int(n / k)
    test = []
    for j in range(i * m, i * m + m):
        test.append(j)

    df_train = df.drop(test)
    df_test = df.iloc[test]
    root = Tree(df_train, response)  # generate the tree on the training fold

    pred_true = 0
    for idx in df_test.index:
        label = Predict(root, df[df.index == idx])
        if label == df_test[response][idx]:
            pred_true += 1

    accuracy = pred_true / len(df_test.index)
    accuracy_scores.append(accuracy)


accuracy_sum = 0
print("accuracy: ", end="")
for i in range(k):
    print("%.3f  " % accuracy_scores[i], end="")
    accuracy_sum += accuracy_scores[i]
print("\naverage accuracy: %.3f" % (accuracy_sum / k))
'''


11. Results

Results of the 10-fold cross-validation



12. Data

编号 色泽 根蒂 敲声 纹理 脐部 触感 密度 含糖率 好瓜
1 青绿 蜷缩 浊响 清晰 凹陷 硬滑 0.697 0.46 是
2 乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 0.774 0.376 是
3 乌黑 蜷缩 浊响 清晰 凹陷 硬滑 0.634 0.264 是
4 青绿 蜷缩 沉闷 清晰 凹陷 硬滑 0.608 0.318 是
5 浅白 蜷缩 浊响 清晰 凹陷 硬滑 0.556 0.215 是
6 青绿 稍蜷 浊响 清晰 稍凹 软粘 0.403 0.237 是
7 乌黑 稍蜷 浊响 稍糊 稍凹 软粘 0.481 0.149 是
8 乌黑 稍蜷 浊响 清晰 稍凹 硬滑 0.437 0.211 是
9 乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 0.666 0.091 否
10 青绿 硬挺 清脆 清晰 平坦 软粘 0.243 0.267 否
11 浅白 硬挺 清脆 模糊 平坦 硬滑 0.245 0.057 否
12 浅白 蜷缩 浊响 模糊 平坦 软粘 0.343 0.099 否
13 青绿 稍蜷 浊响 稍糊 凹陷 硬滑 0.639 0.161 否
14 浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 0.657 0.198 否
15 乌黑 稍蜷 浊响 清晰 稍凹 软粘 0.36 0.37 否
16 浅白 蜷缩 浊响 模糊 平坦 硬滑 0.593 0.042 否
17 青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 0.719 0.103 否

13. References

(https://blog.csdn.net/snoopy_yuan/article/details/68959025)


Reposted from blog.csdn.net/NJYR21/article/details/79942308