周志华《机器学习》第四章决策树-编程尝试
一、导入需要用到的包
'''
import math
import re
from sklearn.externals.six import StringIO
from pydotplus import graphviz
'''
二、定义node类
'''
class Node(object):
    """A single node of the decision tree.

    Attributes:
        attr: the attribute this node splits on; ``None`` marks a leaf
            (``Predict`` stops descending when ``attr`` is ``None``).
        label: the class label stored at this node.
        attr_down: dict mapping a branch key to the child ``Node``.
    """

    def __init__(self, attr_init=None, label_init=None, attr_down_init=None):
        """Store the given attribute, label and child mapping.

        Args:
            attr_init: initial value for ``attr``.
            label_init: initial value for ``label``.
            attr_down_init: initial child mapping; a new empty dict is
                created when omitted or ``None``.
        """
        self.attr = attr_init
        self.label = label_init
        # A literal `{}` default would be created once and shared by every
        # Node built without this argument, so children added to one node
        # would appear on all of them.
        self.attr_down = {} if attr_down_init is None else attr_down_init
'''
三、标签计数字典
'''
def tagDict(tag_attr):
    """Count how many times each label occurs in ``tag_attr``.

    Args:
        tag_attr: any iterable of hashable labels.

    Returns:
        A plain dict mapping each distinct label to its number of
        occurrences, with keys in first-seen order.
    """
    from collections import Counter

    # iter() forces element-by-element counting; handing Counter a mapping
    # directly would copy its values as counts instead of counting its keys.
    return dict(Counter(iter(tag_attr)))
'''
四、属性变量计数字典
'''
def predCount(pred_attr):
    """Count how many times each attribute value occurs in ``pred_attr``.

    Args:
        pred_attr: any iterable of hashable attribute values.

    Returns:
        A plain dict mapping each distinct value to its number of
        occurrences, with keys in first-seen order.
    """
    from collections import Counter

    # iter() forces element-by-element counting; handing Counter a mapping
    # directly would copy its values as counts instead of counting its keys.
    return dict(Counter(iter(pred_attr)))
'''
五、信息熵
'''
def Ent(tag_attr):
    """Return the information entropy (base 2) of the labels in ``tag_attr``.

    Ent(D) = -sum_k p_k * log2(p_k), where p_k is the fraction of samples
    carrying the k-th distinct label.

    Args:
        tag_attr: a sized iterable of hashable labels.

    Returns:
        The entropy in bits; 0 for an empty or single-class input.
    """
    import math
    from collections import Counter

    n = len(tag_attr)
    ent_tag = 0
    # Sum once per DISTINCT label. Looping over the samples themselves would
    # add each class term as many times as the class occurs and inflate the
    # result (e.g. ['a', 'a', 'b', 'b'] would give 2.0 instead of 1.0).
    for count in Counter(iter(tag_attr)).values():
        p = count / n
        ent_tag -= p * math.log2(p)
    return ent_tag
'''
六、信息增益
'''
def Gain(df, predict, response):
    """Compute the information gain of splitting ``df`` on column ``predict``.

    Numeric columns (integer or float dtype) are split by bi-partition: every
    midpoint between two adjacent distinct sorted values is tried and the one
    with the lowest weighted entropy is kept. Other columns are split into
    one branch per distinct value.

    Args:
        df: pandas DataFrame holding the samples.
        predict: name of the candidate attribute column.
        response: name of the label column.

    Returns:
        A ``(gain, div_point)`` tuple. ``div_point`` is the chosen threshold
        for a numeric column and 0 for a categorical one. A numeric column
        with no usable threshold (fewer than two distinct values) yields
        ``(0, 0)``.

    NOTE(review): 0 doubles as the "categorical" sentinel, so a numeric
    threshold that is exactly 0 cannot be told apart from it by callers.
    """
    gain = Ent(df[response])
    div_point = 0
    n = len(df[predict])
    # `dtype == ('float64', 'int64')` is not a membership test, so numeric
    # columns never reached this branch; check the dtype kind instead.
    if df[predict].dtype.kind in "iuf":
        ordered = df.sort_values(by=[predict], ascending=True)
        ordered = ordered.reset_index(drop=True)
        pred_attr = ordered[predict]
        tag_attr = ordered[response]
        best_ent = None
        for i in range(1, n):
            low = pred_attr.iloc[i - 1]
            high = pred_attr.iloc[i]
            if low == high:
                # A threshold equal to the value itself cannot separate the
                # first i rows from the rest, so it is not a real candidate.
                continue
            # Left part is rows [0, i), right part is rows [i, n); the right
            # slice must start at i so that no sample is left out.
            split_ent = (i * Ent(tag_attr.iloc[:i])
                         + (n - i) * Ent(tag_attr.iloc[i:])) / n
            if best_ent is None or split_ent < best_ent:
                best_ent = split_ent
                div_point = (low + high) / 2
        if best_ent is None:
            return 0, 0
        gain -= best_ent
    else:
        pred_attr = df[predict]
        tag_attr = df[response]
        for pred, count in predCount(pred_attr).items():
            # Weight each branch's entropy by its share of the samples.
            gain -= count * Ent(tag_attr[pred_attr == pred]) / n
    return gain, div_point
'''
七、最优属性
'''
def getOptPred(df, response):
    """Pick the attribute of ``df`` with the largest information gain.

    Args:
        df: pandas DataFrame holding the samples.
        response: name of the label column, which is never a candidate.

    Returns:
        A ``(opt_pred, div_point)`` tuple: the name of the best attribute and
        the split point reported for it by ``Gain``. Ties go to the earlier
        column. ``(None, 0)`` is returned when ``df`` has no column other
        than ``response``.
    """
    # Start from safe defaults: previously both names were only assigned
    # inside the loop, so a frame whose best gain was not above 0 (or that
    # had no predictor columns) ended in UnboundLocalError.
    opt_pred = None
    div_point = 0
    best_gain = float("-inf")
    for pred in df.columns:
        if pred == response:
            continue
        gain_tmp, div_point_tmp = Gain(df, pred, response)
        if gain_tmp > best_gain:
            best_gain = gain_tmp
            opt_pred = pred
            div_point = div_point_tmp
    return opt_pred, div_point
'''
八、主函数(递归生成决策树)
'''
def Tree(df, response):
    """Recursively build a decision tree from ``df``.

    Args:
        df: pandas DataFrame with the training samples.
        response: name of the label column.

    Returns:
        The root ``Node``. Every node carries the majority label of the
        samples that reached it; a node whose ``attr`` is ``None`` is a leaf.
        Children are stored in ``attr_down`` under the attribute value
        (categorical split) or under the strings ``"<=%.3f"`` / ``">%.3f"``
        formatted with the threshold (numeric split).

    NOTE(review): the numeric branch keys keep only three decimals of the
    threshold, so a threshold recovered from a key is the rounded value.
    """
    new_node = Node(None, None, {})
    tag_attr = df[response]
    tag_Dict = tagDict(tag_attr)
    if tag_Dict:  # an empty sample set has no majority label
        new_node.label = max(tag_Dict, key=tag_Dict.get)
    # Leaf cases: no samples, a single class, or no attribute left to split on.
    if len(tag_Dict) <= 1 or len(df.columns) <= 1:
        return new_node
    opt_pred, div_point = getOptPred(df, response)
    if opt_pred is None:
        # Defensive: without a chosen attribute there is nothing to split on.
        return new_node
    new_node.attr = opt_pred
    if div_point == 0:
        # Categorical split: one child per observed value; the attribute is
        # removed from the subset because it cannot be used again below.
        for attr in predCount(df[opt_pred]):
            df_attr = df[df[opt_pred].isin([attr])]
            df_attr = df_attr.drop([opt_pred], axis=1)
            new_node.attr_down[attr] = Tree(df_attr, response)
    else:
        # Numeric split: the column is kept, since it may be split again at
        # another threshold further down.
        point_left = "<=%.3f" % div_point
        # The right partition is strictly greater than the threshold, and
        # Predict looks the branch up as ">%.3f", so the key must use ">".
        point_right = ">%.3f" % div_point
        df_attr_left = df[df[opt_pred] <= div_point]
        df_attr_right = df[df[opt_pred] > div_point]
        new_node.attr_down[point_left] = Tree(df_attr_left, response)
        new_node.attr_down[point_right] = Tree(df_attr_right, response)
    return new_node
'''
九、预测函数
'''
def Predict(root, df_sample):
    """Walk the tree from ``root`` and return the label predicted for a sample.

    Args:
        root: the node to start from; nodes expose ``attr``, ``label`` and
            ``attr_down``.
        df_sample: pandas DataFrame whose first row is the sample.

    Returns:
        The ``label`` of the node where the walk stops: a leaf, or the last
        node reached when the sample's value has no matching branch.
    """
    while root.attr is not None:
        column = df_sample[root.attr]
        value = column.values[0]

        # A threshold split is recognised by a numeric column plus a string
        # branch key starting with "<=". The previous check,
        # `dtype == ('float64', 'int64')`, is not a membership test and never
        # selected this path.
        left_key = None
        if column.dtype.kind in "iuf":
            for key in root.attr_down:
                if isinstance(key, str) and key.startswith("<="):
                    left_key = key
                    break

        if left_key is not None:
            # Parse everything after "<=" so a negative threshold keeps its
            # sign (the digit-only regex used before dropped it).
            div_value = float(left_key[2:])
            # The other branch is whatever key is not the "<=" one, so both
            # ">x" and ">=x" spellings of the right branch are accepted.
            right_key = next(
                (k for k in root.attr_down if k != left_key), None)
            next_key = left_key if value <= div_value else right_key
        else:
            next_key = value

        if next_key not in root.attr_down:
            # No branch for this value: answer with the current node's label.
            break
        root = root.attr_down[next_key]
    return root.label
'''
十、代入数据
'''
import pandas as pd

# Load the data set, then estimate accuracy by holding out k consecutive
# blocks of rows in turn, training on the remaining rows each time.
data_file_encode = "gb18030"  # encoding of watermelon_3.csv
# NOTE(review): ".../data" is kept exactly as written; it looks like a
# placeholder for the real directory -- confirm before running.
with open(".../data/watermelon_3.csv", mode='r', encoding=data_file_encode) as data_file:
    df = pd.read_csv(data_file)

# The label is taken from the last column, as in the comparison below.
# NOTE(review): every other column is used as an attribute, including any
# row-id column present in the file -- confirm that this is intended.
response = df.columns[-1]

n = len(df.index)
k = 10
m = n // k  # rows per held-out block; rows beyond k * m are never tested
if m == 0:
    raise ValueError("need at least %d rows for %d folds, got %d" % (k, k, n))

accuracy_scores = []  # was never initialised before being appended to
for i in range(k):
    test = list(range(i * m, i * m + m))
    # drop() works on index labels and iloc on positions; the two agree for
    # the default 0..n-1 index that read_csv produces here.
    df_train = df.drop(test)
    df_test = df.iloc[test]
    root = Tree(df_train, response)  # generate the tree
    pred_true = 0
    for row in df_test.index:
        label = Predict(root, df[df.index == row])
        if label == df_test[response][row]:
            pred_true += 1
    accuracy = pred_true / len(df_test.index)
    accuracy_scores.append(accuracy)

accuracy_sum = 0
print("accuracy: ", end="")
for score in accuracy_scores:
    print("%.3f " % score, end="")
    accuracy_sum += score
print("\naverage accuracy: %.3f" % (accuracy_sum / k))
'''
十一、结果
十折交叉验证结果
十二、数据(watermelon_3.csv)
编号 色泽 根蒂 敲声 纹理 脐部 触感 密度 含糖率 好瓜
1 青绿 蜷缩 浊响 清晰 凹陷 硬滑 0.697 0.46 是
2 乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 0.774 0.376 是
3 乌黑 蜷缩 浊响 清晰 凹陷 硬滑 0.634 0.264 是
4 青绿 蜷缩 沉闷 清晰 凹陷 硬滑 0.608 0.318 是
5 浅白 蜷缩 浊响 清晰 凹陷 硬滑 0.556 0.215 是
6 青绿 稍蜷 浊响 清晰 稍凹 软粘 0.403 0.237 是
7 乌黑 稍蜷 浊响 稍糊 稍凹 软粘 0.481 0.149 是
8 乌黑 稍蜷 浊响 清晰 稍凹 硬滑 0.437 0.211 是
9 乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 0.666 0.091 否
10 青绿 硬挺 清脆 清晰 平坦 软粘 0.243 0.267 否
11 浅白 硬挺 清脆 模糊 平坦 硬滑 0.245 0.057 否
12 浅白 蜷缩 浊响 模糊 平坦 软粘 0.343 0.099 否
13 青绿 稍蜷 浊响 稍糊 凹陷 硬滑 0.639 0.161 否
14 浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 0.657 0.198 否
15 乌黑 稍蜷 浊响 清晰 稍凹 软粘 0.36 0.37 否
16 浅白 蜷缩 浊响 模糊 平坦 硬滑 0.593 0.042 否
17 青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 0.719 0.103 否
十三、参考
https://blog.csdn.net/snoopy_yuan/article/details/68959025