# CART 决策树的生成

# coding=utf-8
import math
from collections import Counter

'''
CART决策树模型,假设有四个条件(特征):
    年龄,有三个选项   1 表示老年人   2 表示中年人   3 表示青年人
    工作,有两个选项   1 表示有工作   2 表示没有工作
    房子,有两个选项   1 表示有房子   2 表示没有房子
    信贷情况,有三个选项   1 表示一般   2 表示好   3 表示非常好
    输出,有两个选项   1 表示给贷款   2 表示不予贷款
要求: 依次计算每个特征各个取值的基尼指数,选择基尼指数最小的特征及取值作为切分点
'''

# Row-count threshold consulted by proc_node when deciding whether to keep
# splitting a node's partitions.
min_row_count = 2

# Readable meaning of every feature code: one dict per feature column, in the
# same order as g_columns.
g_desc = [
    {1: "老年人", 2: "中年人", 3: "青年人"},  # age
    {1: "有工作", 2: "没工作"},  # job
    {1: "有房子", 2: "没有房子"},  # house
    {1: "信贷一般", 2: "信贷好", 3: "信贷非常好"},  # credit
]

# Display names of the feature columns, indexed by column number.
g_columns = ["年龄", "工作", "房子", "信贷"]

# Training table. Each row is: age, job, house, credit (coded as in g_desc),
# followed by the class label (1 = loan granted, 2 = loan refused).
sample_input = [
    # age code 3
    [3, 2, 2, 1, 2],
    [3, 2, 2, 2, 2],
    [3, 1, 2, 2, 1],
    [3, 1, 1, 1, 1],
    [3, 2, 2, 1, 2],
    # age code 2
    [2, 2, 2, 1, 2],
    [2, 2, 2, 2, 2],
    [2, 1, 1, 2, 1],
    [2, 2, 1, 3, 1],
    [2, 2, 1, 3, 1],
    # age code 1
    [1, 2, 1, 3, 1],
    [1, 2, 1, 2, 1],
    [1, 1, 2, 2, 1],
    [1, 1, 2, 3, 1],
    [1, 2, 2, 1, 2],
]


class Node:
    """A split node of the CART tree.

    Attributes:
        index: column index of the feature the split tests.
        value: feature value of the split; ``get_proper_tezheng`` puts rows
            whose column equals it in ``dataset_1`` and the rest in
            ``dataset_2``.
        jini: weighted Gini index of the split.
        dataset_1, dataset_2: the two row partitions.
        left, right: child nodes, ``None`` until attached.
    """

    def __init__(self, index, value, jini, dataset_1, dataset_2):
        self.index, self.value, self.jini = index, value, jini
        self.dataset_1, self.dataset_2 = dataset_1, dataset_2
        # Children are attached afterwards through set_left / set_right.
        self.left = self.right = None

    def set_ds1(self, ds1):
        """Replace the first partition (``dataset_1``)."""
        self.dataset_1 = ds1

    def set_ds2(self, ds2):
        """Replace the second partition (``dataset_2``)."""
        self.dataset_2 = ds2

    def set_left(self, left):
        """Attach ``left`` as the left child."""
        self.left = left

    def set_right(self, right):
        """Attach ``right`` as the right child."""
        self.right = right


def get_proper_tezheng(sample_input, right_index):
    """Find the best binary CART split on feature column ``right_index``.

    Each distinct value ``v`` of the column is tried as the test
    ``row[right_index] == v``: matching rows form D1, the others form D2,
    and the candidate is scored by the weighted Gini index

        |D1| / |D| * Gini(D1) + |D2| / |D| * Gini(D2)

    The lowest score wins; on ties the value seen first in ``sample_input``
    is kept.

    Args:
        sample_input: list of rows; ``get_gini`` reads the last element of
            each row as its class label.
        right_index: index of the feature column to evaluate.

    Returns:
        A ``Node`` carrying ``right_index``, the chosen value, its weighted
        Gini index, and the two partitions (matching rows in ``dataset_1``,
        the others in ``dataset_2``).  For an empty ``sample_input`` there is
        no candidate: the node has value 0, Gini ``float("inf")`` and empty
        partitions, so it never wins a smaller-Gini comparison.
    """
    row_count = len(sample_input)
    best_value = 0
    # inf rather than a magic "large" number: any real candidate beats it.
    best_gini = float("inf")
    best_matching, best_others = [], []

    # dict.fromkeys de-duplicates while keeping first-occurrence order,
    # which fixes the order (and thus the tie-breaking) of the candidates.
    for value in dict.fromkeys(row[right_index] for row in sample_input):
        matching, others = [], []
        for row in sample_input:
            if row[right_index] == value:
                matching.append(row)
            else:
                others.append(row)
        weighted_gini = (len(matching) / row_count) * get_gini(matching) + (
            len(others) / row_count) * get_gini(others)
        if weighted_gini < best_gini:
            best_value, best_gini = value, weighted_gini
            # Keep the winning partitions so they need not be rebuilt later.
            best_matching, best_others = matching, others

    return Node(right_index, best_value, best_gini, best_matching, best_others)


def get_finall_node(sample_input):
    """Choose the best split of ``sample_input`` across all feature columns.

    Every column except the last one (the class label) is evaluated with
    ``get_proper_tezheng``; the candidate with the smallest weighted Gini
    index is returned, ties going to the lower column index.  If no
    candidate scores strictly below 1, a placeholder
    ``Node(0, 0, 1, [], [])`` with empty partitions is returned instead.

    Raises:
        IndexError: if ``sample_input`` is empty (its first row is read to
            count the columns).
    """
    feature_total = len(sample_input[0]) - 1
    candidates = [get_proper_tezheng(sample_input, col) for col in range(feature_total)]
    # The placeholder goes first: min() only replaces its running best on a
    # strictly smaller key, so the placeholder survives unless beaten, and
    # among equal scores the earliest column wins.
    return min([Node(0, 0, 1, [], [])] + candidates, key=lambda node: node.jini)


def get_gini(dataset):
    """Return the Gini impurity of the class labels in ``dataset``.

    ``Gini(D) = 1 - sum_k (|C_k| / |D|) ** 2`` where ``C_k`` is the set of
    rows whose last element (the class label) equals ``k``.

    Args:
        dataset: sequence of rows; only the last element of each row is read.

    Returns:
        A float in ``[0, 1)``; 0.0 for a pure dataset.  An empty dataset also
        returns 0.0 (it holds nothing impure) instead of the 1 that the bare
        formula yields when the sum is empty.
    """
    row_count = len(dataset)  # number of rows (samples), not columns
    if row_count == 0:
        return 0.0
    # Counter keeps labels in first-seen order; the plain loop below adds
    # the squared class shares in that order.
    label_counts = Counter(row[-1] for row in dataset)
    share_sq_sum = 0
    for count in label_counts.values():
        share = count / row_count
        share_sq_sum += share * share
    return 1 - share_sq_sum


def proc_node(current_node):
    """Grow the CART subtree below ``current_node`` in place.

    Each partition of ``current_node`` (``dataset_1`` -> left child,
    ``dataset_2`` -> right child) is split again with ``get_finall_node``,
    and the resulting child is processed recursively.  A partition is left
    as a leaf (no child node is attached) when

    * it has at most ``min_row_count`` rows,
    * it is already pure (Gini impurity 0), or
    * the split chosen for it leaves one side empty (e.g. every feature is
      constant on it): such a "split" separates nothing, and recursing on
      it would never terminate.

    Args:
        current_node: a ``Node`` whose ``dataset_1`` / ``dataset_2`` are set.

    Returns:
        ``current_node``, with its ``left`` / ``right`` subtrees attached.
    """
    branches = ((current_node.dataset_1, current_node.set_left),
                (current_node.dataset_2, current_node.set_right))
    for subset, attach in branches:
        if len(subset) <= min_row_count or get_gini(subset) == 0:
            continue
        child = get_finall_node(subset)
        if not child.dataset_1 or not child.dataset_2:
            continue
        attach(child)
        # Each recursion works on a strictly smaller subset, so it terminates.
        proc_node(child)
    return current_node


def bianli_node(root):
    """Print the tree rooted at ``root`` in pre-order.

    Each node produces one line: its weighted Gini index, three spaces, then
    the name (from ``g_columns``) of the feature column it splits on.  A
    node's left subtree is printed before its right subtree.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        print("   ".join((str(node.jini), g_columns[node.index])))
        # Push the right child first so the left child is popped first.
        if node.right:
            pending.append(node.right)
        if node.left:
            pending.append(node.left)

# Script driver: pick the root split over the whole training table, expand
# it with proc_node, then print the resulting nodes with bianli_node.
root = get_finall_node(sample_input)
for _stage in (proc_node, bianli_node):
    _stage(root)

# 猜你喜欢

# 转载自 blog.csdn.net/dasgk/article/details/81084590