首先这篇的格式可能会乱,markdown装上以后,有时候是用csdn原来的编辑器,有时候就变成了markdown编辑器,蒙。
更蒙的是,大牛的代码太飘逸了,有点看不懂,惯例先来原地址:https://blog.csdn.net/Snoopy_Yuan/article/details/68959025
核心步骤有一个思路很赞,就是求信息增益时,用数组存储按某属性分类后,类别样本数,从而便于对熵求和。
没有写注释的,就是实在坚持不住贴了源代码的,例如删去节点、k折验证、画树等,以后得重新看一遍。
另外,在 PyCharm 里画树那段代码不能通过:write_png 函数会报错
pydotplus.graphviz.InvocationException: GraphViz's executables not found
估计是graphviz的环境变量设置有关,但mac系统不知道怎么整,留坑。
感觉自己孕傻了,小柚子,努力比你老母亲聪明一些啊
上代码
分两个程序,主程序entropydecisiontree
# ---- data loading ----
import pandas as pd

# NOTE: gb18030 (a superset of GBK for Chinese text) was tried as the CSV
# encoding but failed to decode, so the file is opened with the platform
# default encoding.
csv_path = "/Users/huatong/PycharmProjects/Data/watermelon_33.csv"
with open(csv_path, mode="r") as data_file:
    df = pd.read_csv(data_file)

# ---- exploratory plotting with seaborn ----
import matplotlib as mpl
import matplotlib.pyplot as plt  # MATLAB-like plotting interface
import seaborn as sns

# Use SimHei so Chinese labels render on the figures
# (the original Droid Sans Fallback font was not available).
mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False  # keep minus signs readable with a CJK font
sns.set_context("poster")  # enlarge figure elements

# Exploratory plots kept for reference:
# f1 = plt.figure(1)
# sns.FacetGrid(df, hue="好瓜", size=5).map(plt.scatter, "密度", "含糖率").add_legend()
# plt.show()
# f2 = plt.figure(2)
# sns.swarmplot(x="纹理", y="含糖率", hue="好瓜", data=df)
# plt.show()
# ---- k-fold cross validation of the ID3 decision tree ----
import decision_tree

decision_tree.TreeGenerate(df)

accuracy_scores = []
n = len(df.index)
k = 5
for fold in range(k):
    # Fold boundaries computed with integer arithmetic so that ALL n samples
    # are tested even when n is not divisible by k (the original used
    # m = n // k per fold and silently never tested the last n % k rows).
    lo = fold * n // k
    hi = (fold + 1) * n // k
    test_idx = list(range(lo, hi))
    df_train = df.drop(test_idx)
    df_test = df.iloc[test_idx]
    root = decision_tree.TreeGenerate(df_train)  # build the tree on the other k-1 folds

    # Accuracy on the held-out fold. The inner loop variable is no longer
    # named `i` (the original shadowed the outer fold index), and the
    # sample is taken from df_test rather than re-indexing the full df.
    pred_true = 0
    for row in df_test.index:
        label = decision_tree.Predict(root, df_test[df_test.index == row])
        if label == df_test[df_test.columns[-1]][row]:
            pred_true += 1
    accuracy_scores.append(pred_true / len(df_test.index))

# Report per-fold and mean accuracy.
print("accuracy: ", end="")
for score in accuracy_scores:
    print("%.3f " % score, end="")
print("\naverage accuracy: %.3f" % (sum(accuracy_scores) / k))

# Decision-tree visualization using pydotplus/graphviz.
root = decision_tree.TreeGenerate(df)
decision_tree.DrawPNG(root, "decision_tree_ID3.png")
负责生成树、画树的decision_tree.py
# Instantiated by TreeGenerate in the main program.
class Node(object):
    """A decision-tree node.

    Attributes:
        attr: name of the attribute this node splits on (None for a leaf,
            e.g. 纹理).
        label: majority class label at this node (the prediction when the
            node is a leaf).
        attr_down: dict mapping an attribute value (e.g. 乌黑/青绿/浅白) or a
            continuous-split key (e.g. "<=0.381") to the child Node.
    """

    def __init__(self, attr_init=None, label_init=None, attr_down_init=None):
        self.attr = attr_init
        self.label = label_init
        # Bug fix: the original used a mutable default argument ({}), so
        # every Node built with the default shared ONE children dict.
        self.attr_down = {} if attr_down_init is None else attr_down_init
def TreeGenerate(df):
    """Recursively build an ID3 decision tree.

    @param df: pandas.DataFrame — first column is an id, last column the
        class label, columns in between are candidate attributes.
    @return: Node, root of the generated (sub)tree.
    """
    new_node = Node(None, None, {})
    label_arr = df[df.columns[-1]]  # class-label column (好瓜)

    label_count = NodeLabel(label_arr)
    if label_count:  # non-empty sample set
        # Majority class: the prediction if this node ends up a leaf.
        new_node.label = max(label_count, key=label_count.get)

        # Leaf if all samples share one class or no samples remain.
        if len(label_count) == 1 or len(label_arr) == 0:
            return new_node

        # Leaf if no candidate attributes are left (only id + label columns).
        if len(df.columns) <= 2:
            return new_node

        # Choose the attribute with the highest information gain.
        new_node.attr, div_value = OptAttr(df)
        if new_node.attr is None:  # no attribute yields positive gain
            return new_node

        if div_value == 0:  # discrete attribute: one child per value
            value_count = ValueCount(df[new_node.attr])
            for value in value_count:
                df_v = df[df[new_node.attr].isin([value])]
                # Bug fix: the original assigned the dropped frame to a typo
                # name (dv_v) and recursed on the un-dropped df_v, so the
                # used attribute was never removed from the subtree's data.
                df_v = df_v.drop(columns=new_node.attr)
                new_node.attr_down[value] = TreeGenerate(df_v)
        else:  # continuous attribute: binary split at div_value
            value_l = "<=%.3f" % div_value
            value_r = ">%.3f" % div_value
            df_v_l = df[df[new_node.attr] <= div_value]
            df_v_r = df[df[new_node.attr] > div_value]
            new_node.attr_down[value_l] = TreeGenerate(df_v_l)
            new_node.attr_down[value_r] = TreeGenerate(df_v_r)

    return new_node
def NodeLabel(label_arr):
    """Count occurrences of each class label.

    @param label_arr: iterable of class labels
    @return: dict mapping each distinct label to its count
    """
    counts = {}
    for lab in label_arr:
        counts[lab] = counts.get(lab, 0) + 1
    return counts
def OptAttr(df):
    """Pick the attribute with the largest information gain.

    @param df: DataFrame (first column id, last column class label)
    @return: (opt_attr, div_value) — div_value is 0 for a discrete
        attribute and the chosen split point for a continuous one.
        opt_attr is None when no attribute has positive gain; the
        original left both names unbound and raised UnboundLocalError
        in that case.
    """
    best_gain = 0
    opt_attr = None
    div_value = 0
    for attr_id in df.columns[1:-1]:  # skip the id (first) and label (last) columns
        gain_tmp, div_tmp = InfoGain(df, attr_id)
        if gain_tmp > best_gain:
            best_gain = gain_tmp
            opt_attr = attr_id
            div_value = div_tmp
    return opt_attr, div_value
def InfoGain(df, index):
    """Information gain obtained by splitting df on attribute `index`.

    @param df: DataFrame whose last column is the class label
    @param index: column name of the attribute to evaluate
    @return: (info_gain, div_value) — div_value is 0 for a discrete
        attribute and the best midpoint split for a continuous one.
    """
    info_gain = InfoEnt(df.values[:, -1])  # entropy of the full label column
    div_value = 0  # split point (stays 0 for discrete attributes)
    n = len(df[index])  # number of samples

    # Continuous attribute: try every midpoint between consecutive sorted
    # values and keep the one with minimal weighted child entropy.
    # (The original compared `dtype == (float, int)`, which is unreliable
    # across numpy versions; dtype.kind 'f'/'i' covers floats and ints.)
    if df[index].dtype.kind in 'fi':
        sub_info_ent = {}  # candidate split point -> weighted child entropy
        df = df.sort_values([index], ascending=True)
        df = df.reset_index(drop=True)  # restore a 0..n-1 index after sorting
        data_arr = df[index]
        label_arr = df[df.columns[-1]]
        for i in range(n - 1):
            div = (data_arr[i] + data_arr[i + 1]) / 2  # midpoint candidate
            # Bug fix: the original used label_arr[i+1:-1] for the right
            # split, which silently dropped the last sample.
            sub_info_ent[div] = ((i + 1) * InfoEnt(label_arr[0:i + 1]) / n
                                 + (n - i - 1) * InfoEnt(label_arr[i + 1:]) / n)
        # Maximum gain corresponds to minimum weighted child entropy.
        div_value, sub_info_ent_min = min(sub_info_ent.items(), key=lambda x: x[1])
        info_gain -= sub_info_ent_min
    # Discrete attribute: subtract each value's weighted subset entropy.
    else:
        data_arr = df[index]
        label_arr = df[df.columns[-1]]
        value_count = ValueCount(data_arr)
        for key in value_count:
            key_label_arr = label_arr[data_arr == key]
            info_gain -= value_count[key] * InfoEnt(key_label_arr) / n
    return info_gain, div_value
def InfoEnt(label_arr):
    """Shannon entropy (base 2) of a sequence of class labels.

    @param label_arr: iterable of labels
    @return: float entropy; 0 for an empty sequence (the original raised
        ZeroDivisionError on empty input, and its try/except around the
        log2 import only printed a message before failing with NameError).
    """
    from math import log2

    n = len(label_arr)
    if n == 0:
        return 0
    # Count labels inline so the function is self-contained.
    counts = {}
    for lab in label_arr:
        counts[lab] = counts.get(lab, 0) + 1
    ent = 0
    for c in counts.values():
        p = c / n
        ent -= p * log2(p)
    return ent
def ValueCount(data_arr):
    """Count how many samples take each value of an attribute.

    @param data_arr: iterable of attribute values
    @return: dict mapping each distinct value to its frequency
    """
    freq = {}
    for v in data_arr:
        try:
            freq[v] += 1
        except KeyError:
            freq[v] = 1
    return freq
# Predict the class of one sample by walking the tree from the root.
def Predict(root, df_sample):
    """Classify a single-row DataFrame df_sample using the tree at root.

    Walks down attr_down branches until a leaf (attr is None) or an
    attribute value unseen during training is met, then returns that
    node's label.
    """
    try:
        import re  # regex extracts the split number from keys like "<=0.381"
    except ImportError:
        print("module re not found")
    while root.attr != None:
        # Continuous attribute: branch keys use the "<=%.3f" / ">%.3f"
        # format produced by TreeGenerate — the two must stay in sync.
        if df_sample[root.attr].dtype == (float, int):
            # Recover div_value from the first branch key.
            # NOTE(review): assumes attr_down is non-empty here, otherwise
            # div_value stays unbound — confirm TreeGenerate guarantees it.
            for key in list(root.attr_down):
                num = re.findall(r"\d+\.?\d*", key)
                div_value = float(num[0])
                break
            if df_sample[root.attr].values[0] <= div_value:
                key = "<=%.3f" % div_value
                root = root.attr_down[key]
            else:
                key = ">%.3f" % div_value
                root = root.attr_down[key]
        # Discrete attribute: follow the branch for the sample's value.
        else:
            key = df_sample[root.attr].values[0]
            # Fall back to the current node's majority label when the
            # value never appeared in the training split.
            if key in root.attr_down:
                root = root.attr_down[key]
            else:
                break
    return root.label
def DrawPNG(root, out_file):
    """Render the decision tree to a PNG via pydotplus/graphviz.

    @param root: Node, root of the tree to draw
    @param out_file: str, name/path of the output PNG file

    Requires the GraphViz executables on PATH; otherwise write_png raises
    pydotplus.graphviz.InvocationException.
    """
    # Bug fix: the original did an unconditional `import graphviz` before
    # the try block, which (a) crashed when the graphviz package was
    # absent even though pydotplus was handled below, and (b) was then
    # immediately shadowed by the pydotplus import anyway.
    try:
        from pydotplus import graphviz
    except ImportError:
        print("module pydotplus.graphviz not found")
    g = graphviz.Dot()  # fresh dot graph (requires the pydotplus module)
    TreeToGraph(0, g, root)
    g2 = graphviz.graph_from_dot_data(g.to_string())
    g2.write_png(out_file)
def TreeToGraph(i, g, root):
    '''
    Build the graphviz graph recursively from root on.
    @param i: running node number within the tree (unique per node)
    @param g: pydotplus.graphviz.Dot() object being filled in
    @param root: the (sub)tree root Node to add
    @return i: the highest node number used so far (threaded through the
        recursion so every node gets a unique number)
    @return g_node: the graphviz node id of this root
    '''
    try:
        from pydotplus import graphviz
    except ImportError:
        print("module pydotplus.graphviz not found")
    # Leaf nodes show only the label; internal nodes also show the split attribute.
    if root.attr == None:
        g_node_label = "Node:%d\n好瓜:%s" % (i, root.label)
    else:
        g_node_label = "Node:%d\n好瓜:%s\n属性:%s" % (i, root.label, root.attr)
    g_node = i
    g.add_node(graphviz.Node(g_node, label=g_node_label))
    # Recurse into each child branch; the returned counter keeps node
    # numbers unique across the whole tree, and an edge labelled with the
    # attribute value links parent to child.
    for value in list(root.attr_down):
        i, g_child = TreeToGraph(i + 1, g, root.attr_down[value])
        g.add_edge(graphviz.Edge(g_node, g_child, label=value))
    return i, g_node
今天收到了第一条留言,发现共享数据文件还是有需求的。手打数据不容易,希望有第一个赞~
源数据来自《机器学习》第84页表4.3,西瓜数据3.0
链接:https://pan.baidu.com/s/1MdY3j6litrX2o671wKQf8g 密码:0t82