机器学习系列之决策树

最近想把每个机器学习的算法,重新学习一遍。最好能自己编写一遍,但是一方面编程能力欠缺,另一方面时间有限。所以大本分代码都是跟着别人的技术博客,照葫芦画瓢。
无论是编程能力,还是机器学习算法,都有待进一步提升。请注意下面的代码不完整,完整代码请参照下面分享的大牛的技术博客。


#!/usr/bin/env python
# -*- coding:utf-8 -*-
__author__ = 'Great'

"""
输入:数据集
输出:决策树(分类结果)

#伪代码
def 创建决策树:
    
    if 数据集样本分类一致:
       创建带类标签的叶子节点
    else:
       寻找划分数据集,信息熵增益最大的特征
       据此划分数据集
       for 每个划分后的数据集:
           创建之树(递归)

def 加载数据集
def 计算熵
def 数据集划分
def 根据熵增益选择最佳划分
def 递归构建决策树
def 样本分类
def matplotlib 显示
def 决策树存储
"""

"""计算信息熵
H(x) = -∑[p(x)log2(p(x))]
熵是基于每种类别的概率计算
"""
from math import log
def entropy_cal(data):
    label = {}
    for one in data:
        label_data = one[-1]
        if label_data not in label.keys():
            label[label_data] =0
        label[label_data] += 1

    length = len(data)
    h_entropy = 0
    for item in label:
        px = label[item]/length
        h_entropy -= float(px)*log(px, 2)

    return h_entropy

"""数据集"""
def get_data():
    dataset = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    return dataset

"""测试"""

data = get_data()
test = entropy_cal(data)
print(test)

#数据集的划分
def splitdata(data, axis, value):
    next_data = []
    for one in data:
        if one[axis] == value:
            data_next = data[:axis]
            data_next.extend(data[axis+1:])
            next_data.append(data_next)
    return next_data
#计算信息熵增益
G = H(D) - H(Di|xi)
H(Di|xi) = -∑(Di/D)*∑[p(x)log2(p(x))]

#最佳数据集划分特征
def bestfeature(data):

    best_gain = 0
    base_entropy = entropy_cal(data)
    best_feat = -1

    h_f_entropy = 0
    length = len(data[0]) - 1
    for i in range(length):
        feat_list = [item[i] for item in data]
        unique_feat = set(feat_list)

        new_entropy = 0
        for value in unique_feat:
            sub_data = splitdata(data, i, value)
            prob = len(sub_data)/float(len(data))
            new_entropy += prob*entropy_cal(sub_data)

        get_gain = base_entropy - new_entropy
        if (get_gain > best_gain):
            best_gain = get_gain
            best_feat = i
    return best_feat

'''测试2'''

#data = get_data()
feat = bestfeature(data)
#print(feat)

不完整代码,未完成,待续。。。
参考:统计学习方法
https://www.cnblogs.com/muchen/p/6141978.html
https://www.cnblogs.com/luozeng/p/8604997.html
https://www.cnblogs.com/lianjiehere/p/6862890.html
 

猜你喜欢

转载自blog.csdn.net/weixin_41512727/article/details/80935832