Python实现决策树应用之判断鱼类与非鱼类

版权声明:学习交流为主,未经博主同意禁止转载,禁止用于商用。 https://blog.csdn.net/u012965373/article/details/83929654

代码模块一:DecisionTreePlot  

# -*- coding:utf-8 -*-
__author__ = 'yangxin_ryan'

import matplotlib.pyplot as plt
"""
定义文本框 和 箭头格式 
【 sawtooth 波浪方框, round4 矩形方框 , fc表示字体颜色的深浅 0.1~0.9 依次变浅,没错是变浅】
"""
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


class DecisionTreePlot(object):
    """Draw a nested-dict decision tree with matplotlib annotations.

    A tree is a dict with exactly one key (the feature label) whose value is
    a dict mapping each feature value to either a sub-tree (another dict of
    the same shape) or a leaf (any non-dict value, used as the class label).
    """

    def __init__(self):
        # Layout state. It is (re)initialised by create_plot() before any
        # drawing happens; the plot_* methods must not be called earlier.
        self._ax = None        # matplotlib axes the tree is drawn on
        self._total_w = None   # total number of leaves (tree width)
        self._total_d = None   # tree depth
        self._x_off = None     # x position of the most recently placed leaf
        self._y_off = None     # y position of the level being drawn

    def get_num_leafs(self, my_tree):
        """Return the number of leaf nodes in ``my_tree``.

        :param my_tree: tree in the nested-dict format described on the class.
        :return: count of non-dict values reachable from the root.
        """
        # next(iter(...)) instead of keys()[0]: dict views cannot be indexed
        # on Python 3.
        first_str = next(iter(my_tree))
        num_leafs = 0
        for child in my_tree[first_str].values():
            if isinstance(child, dict):
                num_leafs += self.get_num_leafs(child)
            else:
                num_leafs += 1
        return num_leafs

    def get_tree_depth(self, my_tree):
        """Return the number of decision levels on the longest root-to-leaf path.

        :param my_tree: tree in the nested-dict format described on the class.
        :return: 1 for a root whose children are all leaves, plus 1 per
            additional nested decision node; 0 if the root has no children.
        """
        first_str = next(iter(my_tree))
        max_depth = 0
        for child in my_tree[first_str].values():
            if isinstance(child, dict):
                this_depth = 1 + self.get_tree_depth(child)
            else:
                this_depth = 1
            max_depth = max(max_depth, this_depth)
        return max_depth

    def plot_node(self, node_txt, center_pt, parent_pt, node_type):
        """Draw one boxed node at ``center_pt`` with an arrow linking it to ``parent_pt``.

        :param node_txt: text shown inside the box.
        :param center_pt: (x, y) of the box, in axes-fraction coordinates.
        :param parent_pt: (x, y) of the parent node, in axes-fraction coordinates.
        :param node_type: bbox style dict (``decisionNode`` or ``leafNode``).
        """
        self._ax.annotate(node_txt, xy=parent_pt, xycoords='axes fraction', xytext=center_pt,
                          textcoords='axes fraction', va="center", ha="center", bbox=node_type,
                          arrowprops=arrow_args)

    def plot_mid_text(self, cntr_pt, parent_pt, txt_string):
        """Write ``txt_string`` at the midpoint of the edge between two nodes.

        :param cntr_pt: (x, y) of the child node.
        :param parent_pt: (x, y) of the parent node.
        :param txt_string: edge label (the feature value leading to the child).
        """
        x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
        y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
        self._ax.text(x_mid, y_mid, txt_string, va="center", ha="center", rotation=30)

    def plot_tree(self, my_tree, parent_pt, node_txt):
        """Recursively draw ``my_tree`` below ``parent_pt``.

        Relies on the layout state set up by create_plot().

        :param my_tree: (sub-)tree in the nested-dict format.
        :param parent_pt: (x, y) of the parent node the root of ``my_tree``
            is connected to.
        :param node_txt: label for the edge from the parent to this sub-tree.
        """
        num_leafs = self.get_num_leafs(my_tree)
        # Centre the decision node horizontally over the leaves it owns.
        cntr_pt = (self._x_off + (1.0 + float(num_leafs)) / 2.0 / self._total_w, self._y_off)
        self.plot_mid_text(cntr_pt, parent_pt, node_txt)
        first_str = next(iter(my_tree))
        self.plot_node(first_str, cntr_pt, parent_pt, decisionNode)
        # Step down one level for the children of this node.
        self._y_off -= 1.0 / self._total_d
        for key, child in my_tree[first_str].items():
            if isinstance(child, dict):
                self.plot_tree(child, cntr_pt, str(key))
            else:
                # Each leaf takes the next horizontal slot.
                self._x_off += 1.0 / self._total_w
                leaf_pt = (self._x_off, self._y_off)
                self.plot_node(child, leaf_pt, cntr_pt, leafNode)
                self.plot_mid_text(leaf_pt, cntr_pt, str(key))
        # Restore the level so siblings of this sub-tree are drawn correctly.
        self._y_off += 1.0 / self._total_d

    def create_plot(self, in_tree):
        """Create a figure, draw ``in_tree`` on it and show it.

        :param in_tree: tree in the nested-dict format described on the class.
        """
        fig = plt.figure(1, facecolor='green')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])
        self._ax = plt.subplot(111, frameon=False, **axprops)
        self._total_w = float(self.get_num_leafs(in_tree))
        self._total_d = float(self.get_tree_depth(in_tree))
        # Start half a slot to the left so the first leaf lands at 0.5 / total_w.
        self._x_off = -0.5 / self._total_w
        self._y_off = 1.0
        self.plot_tree(in_tree, (0.5, 1.0), '')
        plt.show()

    def retrieve_tree(self, i):
        """Return one of two hard-coded sample trees.

        :param i: index into the sample list (0 or 1).
        :return: the selected tree.
        :raises IndexError: if ``i`` is outside the sample list.
        """
        list_of_trees = [
            {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
            {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
        ]
        return list_of_trees[i]

代码模块二:DescionTreeApp

# -*- coding:utf-8 -*-
__author__ = 'yangxin_ryan'

import operator
from math import log
from src.descion_tree.decision_tree_plot import DecisionTreePlot as dtPlot
import pickle
import copy


class DescionTreeApp(object):
    """Build, use, persist and plot an ID3-style decision tree.

    Data sets are lists of rows; every row is a list whose last element is
    the class label and whose other elements are feature values. Trees use
    the nested-dict format ``{feature_label: {feature_value: subtree_or_class}}``.
    """

    def create_data_set(self):
        """Return the sample fish data set and its feature labels.

        :return: ``(data_set, labels)`` where each row is
            ``[no surfacing, flippers, class]``.
        """
        data_set = [[1, 1, 'yes'],
                    [1, 1, 'yes'],
                    [1, 0, 'no'],
                    [0, 1, 'no'],
                    [0, 1, 'no']]
        labels = ['no surfacing', 'flippers']
        return data_set, labels

    def calc_shannon_ent(self, data_set):
        """Return the Shannon entropy (base 2) of the class labels in ``data_set``.

        :param data_set: rows whose last element is the class label.
        :return: entropy as a float; 0.0 for an empty data set.
        """
        num_entries = len(data_set)
        label_counts = {}
        for feat_vec in data_set:
            current_label = feat_vec[-1]
            label_counts[current_label] = label_counts.get(current_label, 0) + 1
        shannon_ent = 0.0
        for count in label_counts.values():
            prob = float(count) / num_entries
            shannon_ent -= prob * log(prob, 2)
        return shannon_ent

    def split_data_set(self, data_set, index, value):
        """Return the rows whose column ``index`` equals ``value``, with that column removed.

        :param data_set: rows to filter; not modified.
        :param index: column to test and drop.
        :param value: feature value a row must have to be kept.
        :return: new list of new, shortened rows.
        """
        return [feat_vec[:index] + feat_vec[index + 1:]
                for feat_vec in data_set
                if feat_vec[index] == value]

    def choose_best_feature_to_split(self, data_set):
        """Return the index of the feature with the highest information gain.

        :param data_set: non-empty list of rows (last element is the class).
        :return: feature index, or -1 when no feature has a strictly
            positive information gain.
        """
        num_features = len(data_set[0]) - 1
        base_entropy = self.calc_shannon_ent(data_set)
        best_info_gain, best_feature = 0.0, -1
        for i in range(num_features):
            unique_vals = {example[i] for example in data_set}
            # Weighted entropy of the partitions produced by feature i.
            new_entropy = 0.0
            for value in unique_vals:
                sub_data_set = self.split_data_set(data_set, i, value)
                prob = len(sub_data_set) / float(len(data_set))
                new_entropy += prob * self.calc_shannon_ent(sub_data_set)
            info_gain = base_entropy - new_entropy
            if info_gain > best_info_gain:
                best_info_gain = info_gain
                best_feature = i
        return best_feature

    def majority_cnt(self, class_list):
        """Return the most frequent class in ``class_list``.

        Ties are resolved in favour of the class that appears first.

        :param class_list: non-empty list of class labels.
        """
        class_count = {}
        for vote in class_list:
            class_count[vote] = class_count.get(vote, 0) + 1
        # max() keeps the first maximal item, matching a stable descending sort.
        return max(class_count.items(), key=operator.itemgetter(1))[0]

    def create_tree(self, data_set, labels):
        """Recursively build a decision tree from ``data_set``.

        :param data_set: non-empty list of rows (last element is the class).
        :param labels: feature names, one per feature column. The list is
            not modified.
        :return: a class label when the node is a leaf, otherwise a nested
            dict ``{feature_label: {feature_value: subtree_or_class}}``.
        """
        class_list = [example[-1] for example in data_set]
        # All rows share one class: this is a pure leaf.
        if class_list.count(class_list[0]) == len(class_list):
            return class_list[0]
        # No feature columns left: fall back to a majority vote.
        if len(data_set[0]) == 1:
            return self.majority_cnt(class_list)
        best_feat = self.choose_best_feature_to_split(data_set)
        # -1 means no feature separates the classes. Indexing with it would
        # silently split on the class-label column, so vote instead.
        if best_feat == -1:
            return self.majority_cnt(class_list)
        best_feat_label = labels[best_feat]
        my_tree = {best_feat_label: {}}
        # Build the reduced label list without mutating the caller's list.
        sub_labels = labels[:best_feat] + labels[best_feat + 1:]
        unique_vals = {example[best_feat] for example in data_set}
        for value in unique_vals:
            my_tree[best_feat_label][value] = self.create_tree(
                self.split_data_set(data_set, best_feat, value), sub_labels[:])
        return my_tree

    def classify(self, input_tree, feat_labels, test_vec):
        """Classify ``test_vec`` by walking ``input_tree``.

        :param input_tree: tree produced by create_tree().
        :param feat_labels: feature names in the column order of ``test_vec``.
        :param test_vec: feature values of the sample to classify.
        :return: the class label stored at the reached leaf.
        :raises ValueError: if a tree node's label is not in ``feat_labels``.
        :raises KeyError: if ``test_vec`` holds a feature value the tree has
            no branch for.
        """
        first_str = next(iter(input_tree))
        second_dict = input_tree[first_str]
        feat_index = feat_labels.index(first_str)
        value_of_feat = second_dict[test_vec[feat_index]]
        if isinstance(value_of_feat, dict):
            return self.classify(value_of_feat, feat_labels, test_vec)
        return value_of_feat

    def store_tree(self, input_tree, filename):
        """Pickle ``input_tree`` to ``filename``, overwriting any existing file.

        :param input_tree: tree to persist.
        :param filename: path of the file to write.
        """
        # One write is enough; `with` closes the file even if dump() raises.
        with open(filename, 'wb') as fw:
            pickle.dump(input_tree, fw)

    def grab_tree(self, filename):
        """Load and return a tree previously written by store_tree().

        :param filename: path of the pickle file to read.
        :return: the unpickled object.
        """
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(filename, 'rb') as fr:
            return pickle.load(fr)

    # Application test 1: decide fish vs. non-fish
    def app_fish(self):
        """Build the tree for the sample fish data and plot it."""
        my_dat, labels = self.create_data_set()
        my_tree = self.create_tree(my_dat, labels)
        # create_plot is an instance method, so a plotter object is required.
        dtPlot().create_plot(my_tree)

  


if __name__ == "__main__":
    # Script entry point: run the fish / non-fish demo.
    DescionTreeApp().app_fish()

猜你喜欢

转载自blog.csdn.net/u012965373/article/details/83929654