决策树理解与python实现

版权声明:本文为博主原创文章,转载需要注明源地址 https://blog.csdn.net/qq_33512078/article/details/78985304

代码实现请直接移步博文末尾

在机器学习领域,决策树是用于数据分类、预测的模型。决策树算法通过分析训练集的各个数据特征的不同,由原始数据集构造出一个树形结构,比如我们分析一封邮件是否为垃圾邮件时,可以根据发送方域名、邮件主题等方式区分邮件是否为垃圾邮件,新数据通过使用构造出的决策树模型来进行预测。
决策树算法的关键主要是寻找一个最合适的数据特征将数据集区分开来。我使用以下数据进行的测试:

使用以下数据区分西瓜是否为好瓜

编号 色泽 根蒂 敲声 纹理 脐部 触感 好瓜
1 青绿 蜷缩 浊响 清晰 凹陷 硬滑
2 乌黑 蜷缩 沉闷 清晰 凹陷 硬滑

ps: 以上为用到的部分数据

决策树划分数据集的方法伪代码:

function split_data()

if 所有数据都属于同一个分类:
    return 分类名称
else:
    寻找划分数据集的最好特征
    划分数据集
    创建分支节点
    for 每个划分的子集:
        调用函数split_data(),并将返回结果增加到分支节点中
    return 分支节点

上述示例数据集中,是否为好瓜就是分类名称,色泽、根蒂则是用于划分数据集的特征。

所以决策树的基本思想就是不断递归寻找最好的特征,然后使用特征划分数据集成多个,使得子数据集的熵最小。

上面我们引入了“熵”这个词,这个词用于代表数据的混乱程度(在化学等领域也有这个词,意思差不多),我们上面所说的使子数据集们的熵最小,意思也就是使用一个特征划分数据集,使每个子数据集拥有的分类尽可能少。

在这里,我们将熵定义为信息的期望值(ps:这是香农定义的),在计算熵之前,我们首先得知道信息的定义。如果待分类的事物可能划分在多个分类中,这符号 xi 的信息定义为:

l(xi)=log2p(xi)

其中 p(xi) 是选择该分类的概率

为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,有下面的公式:

H=i=1np(xi)log2p(xi)

同时我们引入一个词“信息增益”,信息增益指的是在划分数据集之前之后信息发生的变化。简单来说就是使用某个属性划分数据集之后的信息熵与划分之前的信息熵之差。信息增益越大,表明数据的混乱程度减少的越多,说明更适合使用该属性来划分当前的数据集。

所以构建决策树的主要思想就是,我们通过每次遍历所有属性,尝试划分数据集,寻找信息增益最大的划分,不断递归划分数据,直到所有属性被用完(下面我的代码使用的ID3算法,每次划分消耗一个属性,当然也有不消耗属性的算法)或者各子数据集内部数据类型相同。

根据西瓜数据集生成的决策树样式:
西瓜决策树

python代码实现(ID3算法):

# -*- coding:utf-8 -*-
from math import log


def calc_shannon_ent(dataSet):
    """
    计算香农熵
    :param dataSet: 待计算的数据集
    :return: 香农熵
    """
    data_nums = len(dataSet)
    label_count = {}
    for i in dataSet:
        label = i[-1]
        if label not in label_count.keys():
            label_count[label] = 0
        label_count[label] += 1
    shannon_ent = 0.0
    for i in label_count.keys():
        prob = float(label_count[i])/data_nums
        shannon_ent -= prob*log(prob, 2)
    return shannon_ent


def split_data_set(dataSet, index, value):
    """
    消耗指定属性,划分数据集,返回指定值的数据
    :param dataSet: 待划分的数据集
    :param index: 指定属性所在的索引
    :param value: 返回的数据集对应属性的值
    :return: new_data_set: 新数据集
    """
    new_data_set = []
    for i in dataSet:
        if i[index] == value:
            new_data = i[:index]  # 消耗指定的数据,将数据集划分
            new_data.extend(i[index+1:])
            new_data_set.append(new_data)
    return new_data_set


def choose_best_split_method(dataSet):
    """
    选择最好的用于分类的属性,返回属性索引
    :param dataSet: 数据集
    :return: 用于分类的索引
    """
    dataSet_num = len(dataSet)
    attribute_num = len(dataSet[0]) - 1  # 数据集的属性个数
    base_entropy = calc_shannon_ent(dataSet)  # 未分类的香农熵
    best_information_gain = 0.0  # 最好的信息增益的值
    best_index = -1  # 最好信息增益对应的属性索引
    for i in range(attribute_num):
        attributes = [one_data[i] for one_data in dataSet]
        attributes = set(attributes)
        entropy = 0.0  # 香农熵
        for value in attributes:
            sub_data_set = split_data_set(dataSet, i, value)
            prob = len(sub_data_set)/float(dataSet_num)
            entropy += prob*calc_shannon_ent(sub_data_set)  # 对划分的两个数据集香农熵求均值
        information_gain = base_entropy - entropy  # 求此次划分的信息增益
        if information_gain > best_information_gain:
            best_information_gain = information_gain
            best_index = i
    return best_index


def majority_cnt(class_list):
    """
    返回出现最多次数的分类名称
    :param class_list: 数据集中所有目标值分类的列表
    :return: 出现次数最多的目标分类名称
    """
    class_count = {}
    for i in class_list:
        if i not in class_count.keys():
            class_count[i] = 0
        class_count[i] += 1
    sorted_class_count = sorted(class_count.iteritems(), key=lambda x: x[1])
    return sorted_class_count[0][0]


def create_tree(dataSet, labels):
    """
    创建决策树
    :param dataSet: 数据集
    :param labels: 所有的属性名称集
    :return: 决策树字典
    """
    class_list = [item[-1] for item in dataSet]
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    if len(dataSet[0]) == 1:  # 每次决策树划分分支都会消耗一个属性,最后只剩一个目标属性,说明无法划分分支了
        return majority_cnt(class_list)
    best_index = choose_best_split_method(dataSet)
    best_label = labels[best_index]
    mytree = {best_label: {}}
    del labels[best_index]
    all_values = [item[best_index] for item in dataSet]
    all_values = set(all_values)
    for i in all_values:
        sub_labels = labels[:]
        mytree[best_label][i] = create_tree(split_data_set(dataSet, best_index, i), sub_labels)
    return mytree


def classify(decision_tree, labels, one_data):
    """
    使用决策树进行分类
    :param decision_tree: 决策树
    :param labels: 属性名称列表
    :param one_data: 需要分类的数据
    :return: 分类结果
    """
    first = decision_tree.keys()[0]
    second_dict = decision_tree[first]
    first_index = labels.index(first)
    for i in second_dict:
        if one_data[first_index] == i:
            if type(second_dict[i]).__name__ == 'dict':
                class_label = classify(second_dict[i], labels, one_data)
            else:
                class_label = second_dict[i]
    return class_label


if __name__ == '__main__':
    # 使用数据集中的前部分数据构造决策树,使用最后一条数据检测决策树的正确性
    dataSet = []
    with open('watermelon.txt', 'r') as fp:
        data = fp.readline()
        while data:
            dataSet.append(data.split(' ')[1:])
            data = fp.readline()
    test_set = dataSet[-1]
    dataSet = dataSet[:-1]
    tree = create_tree(dataSet, [u'色泽', u'根蒂', u'敲声', u'纹理', u'脐部', u'触感'])
    print classify(tree, [u'色泽', u'根蒂', u'敲声', u'纹理', u'脐部', u'触感'], test_set)

本文章仅为个人理解,如有错误欢迎批评指正,转载请注明出处

猜你喜欢

转载自blog.csdn.net/qq_33512078/article/details/78985304
今日推荐