ID3决策树的生成

# coding=utf-8
import math

'''
决策树模型,假设有三个条件
    年龄,有三个选项   1  表示老年人  2 表示中年 人  3 表示青年人
    工作,有两个选项   1   表示有工作  2表示  没有工作
    房子,有两个选项   2  表示有房子   2表示  没有房子
	信贷情况           1表示一般   2表示号  3表示非常好
    输出,有两个选项   1 表示给贷款    2表示  不予贷款
要求: 依次获得每个选项的信息增益
'''

output = []


class Node:
    def __init__(self, desc, child_list_input=[], child_desc=[], child=[]):
        self.desc = desc
        self.child = []
        self.child_list_input = child_list_input
        self.child_desc = child_desc

    def set_child(self, child):
        self.child = child

    def append_child(self, child):
        self.child.append(child)


g_desc = [
    {1: "老年人", 2: "中年人", 3: "青年人"},
    {1: "有工作", 2: "没工作"},
    {1: "有房子", 2: "没有房子"},
    {1: "信贷一般", 2: "信贷好", 3: "信贷非常好"}
]
g_columns = ["年龄", "工作", "房子", "信贷"]
limit_low_shang = 0.2
sample_input = [[3, 2, 2, 1, 2],
                [3, 2, 2, 2, 2],
                [3, 1, 2, 2, 1],
                [3, 1, 1, 1, 1],
                [3, 2, 2, 1, 2],
                [2, 2, 2, 1, 2],
                [2, 2, 2, 2, 2],
                [2, 1, 1, 2, 1],
                [2, 2, 1, 3, 1],
                [2, 2, 1, 3, 1],
                [1, 2, 1, 3, 1],
                [1, 2, 1, 2, 1],
                [1, 1, 2, 2, 1],
                [1, 1, 2, 3, 1],
                [1, 2, 2, 1, 2]]


# 计算 H(D)
def get_proper_column_index(sample_input, desc=""):
    # 获得输出的经验熵
    if len(sample_input) == 0:
        return
    out_put_array = {}
    for i in sample_input:
        index = len(i) - 1
        if i[index] in out_put_array:
            out_put_array[i[index]] = out_put_array[i[index]] + 1
        else:
            out_put_array[i[index]] = 1
    N = len(sample_input)
    for i in out_put_array.keys():
        sample = out_put_array[i]
        out_put_array[i] = {}
        pi = sample / N
        out_put_array[i]['pi'] = pi
        out_put_array[i]['log'] = math.log(pi, 2)
        out_put_array[i]['count'] = sample
    HD = 0
    for i in out_put_array.keys():
        HD += 0 - (out_put_array[i]['pi'] * out_put_array[i]['log'])
    columns_num = len(sample_input[0])
    max_shang = 0
    current_index = 0
    for i in range(0, columns_num - 1):
        result = getTezhengX(sample_input, i, N, HD)
        if result > max_shang:
            max_shang = result
            current_index = i
    global limit_low_shang

    if max_shang < limit_low_shang:
        #print("未能成功分类:" + desc)
        output.append(desc)
        # 返回叶子节点
        return Node(desc, [])
    # 根据 current位置的信息进行样本分割
    result = {}
    global g_desc
    for item in sample_input:
        value = item[current_index]
        if value in result.keys():
            result[value]['item'].append(item)
        else:
            result[value] = {}
            result[value]['item'] = []
            result[value]['desc'] = g_desc[current_index][value]
            result[value]['item'].append(item)
    data_list = []
    desc_list = []
    for item in result.keys():
        data_list.append(result[item]['item'])
        desc_list.append(result[item]['desc'])
    # get_proper_column_index(result[item]['item'], result[item]['desc'])

    return Node(g_columns[current_index], data_list, desc_list)


# 接下来计算年龄特征的信息增益
def getTezhengX(sample_input, index, N, HD):
    result = {}
    for i in sample_input:
        key = i[index]
        if key not in result.keys():
            result[key] = {}
            result[key]['count'] = 0
        # 行元素最后一个是输出
        out = i[len(i) - 1]
        if out in result[key].keys():
            result[key][out] = result[key][out] + 1
        else:
            result[key][out] = 1
        result[key]['count'] = result[key]['count'] + 1
    all_sum = 0
    # result.keys 里面的值是 第i列的不同的值得数组
    for key in result.keys():
        # 当前输出类的概率是
        item = result[key]
        # 去掉count元素
        sum_count = item['count']
        result[key].pop('count')
        item = result[key]
        # 特征X 在所有样本中的概率
        sum_value = sum_count / N
        tmp_sum = 0
        for tmp in item.keys():
            pi = item[tmp] / sum_count
            log = math.log(pi, 2)
            tmp_sum = tmp_sum + pi * log
        sum_value = sum_value * (0 - tmp_sum)
        all_sum = all_sum + sum_value
    return HD - all_sum


def proc_Node(node):
    if len(node.child_list_input) == 0:
        # 说明是叶子节点
        return node
    else:
        node_list = []
        for k, item in enumerate(node.child_list_input):
            # print(node.child_list_input[k])
            tmp = get_proper_column_index(item, node.child_desc[k])
            current_node = proc_Node(tmp)
            node_list.append(current_node)
        node.child = node_list
        return node

def bianli_node(node, depth=0, parent=""):
    if depth == 0:
        print("根节点:"+node.desc)
    else:
        print("父节点:" + parent + "  当前节点:" + node.desc + "  深度:" + str(depth))
    for i, item in enumerate(node.child_list_input):
        bianli_node(node.child[i], depth + 1, node.desc)

root = get_proper_column_index(sample_input)
root = proc_Node(root)

bianli_node(root)

猜你喜欢

转载自blog.csdn.net/dasgk/article/details/81008314
今日推荐