# 代码参照自: https://blog.csdn.net/csqazwsxedc/article/details/65697652
# 原文写得很详细；数据来自西瓜书（周志华《机器学习》）中的西瓜数据集 2.0。
# 个人学习笔记: https://blog.csdn.net/aaalswaaa1/article/details/83024513
from math import log
import operator
from FileUtil import FileUtil
def ent(dataSet):
    """Return the Shannon entropy (in bits) of the class labels in *dataSet*.

    Each sample is a sequence whose last element is its class label.
    An empty data set yields 0.
    """
    total = len(dataSet)
    # Label -> number of samples carrying it, in first-seen order.
    freq = {}
    for sample in dataSet:
        label = sample[-1]
        freq[label] = freq.get(label, 0) + 1
    entropy = 0
    for count in freq.values():
        prob = float(count) / total
        entropy -= prob * log(prob, 2)
    return entropy
def splitData(dataSet, axis, value):
    """Select the samples whose feature *axis* equals *value*.

    Returns a new list of new rows; in each row the column *axis* has been
    removed. *dataSet* itself is not modified.
    """
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]
def chooseBestFeature(dataSet):
    """Return the index of the feature with the largest information gain (ID3).

    Each sample in *dataSet* is a list of feature values followed by its class
    label in the last position. *dataSet* must be non-empty.

    Returns:
        The column index of the chosen feature; ties go to the lowest index.
        Returns -1 only when the samples have no feature columns at all.
    """
    numFeature = len(dataSet[0]) - 1
    numData = len(dataSet)
    baseEnt = ent(dataSet)  # entropy of the labels before any split
    bestFeature = -1
    # Start below any achievable gain so that a feature is always chosen when
    # one exists. Starting at 0 returned -1 whenever every gain was 0 (e.g.
    # XOR-like data), and callers then used -1 as a real column index.
    bestGain = float("-inf")
    for i in range(numFeature):
        # dict.fromkeys keeps first-seen order, so the float summation below
        # does not depend on set iteration order (which varies per process).
        attributes = dict.fromkeys(row[i] for row in dataSet)
        newEnt = 0
        for attribute in attributes:
            subDataSet = splitData(dataSet, i, attribute)
            p = float(len(subDataSet)) / numData
            newEnt += p * ent(subDataSet)
        gain = baseEnt - newEnt
        if gain > bestGain:
            bestGain = gain
            bestFeature = i
    return bestFeature
def majorityCnt(classList):
    """Return the most frequent class label in *classList*.

    Ties are broken in favour of the label that appears first in *classList*.

    Raises:
        ValueError: if *classList* is empty.
    """
    if not classList:
        raise ValueError("majorityCnt() requires a non-empty classList")
    # Count into classCounts (the previous version indexed classList itself,
    # which is a list and has neither .keys() nor string indices).
    classCounts = {}
    for value in classList:
        classCounts[value] = classCounts.get(value, 0) + 1
    # key must be passed by keyword; max() returns the first maximal item, so
    # the earliest-seen label wins a tie.
    return max(classCounts.items(), key=operator.itemgetter(1))[0]
def createTree(dataSet, labels):
    """Build an ID3 decision tree from *dataSet*.

    Args:
        dataSet: non-empty list of samples; each sample is a list of feature
            values followed by its class label in the last position.
        labels: feature names, one per feature column of *dataSet*. The list
            is not modified.

    Returns:
        Either a class label (a leaf), or a nested dict of the form
        ``{featureName: {featureValue: subtree, ...}}``.
    """
    classList = [row[-1] for row in dataSet]
    # Every sample has the same class: leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No feature columns remain: majority-class leaf.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # All samples share identical feature values, so no split can separate
    # them: majority-class leaf (the standard ID3 stopping rule).
    firstFeatures = dataSet[0][:-1]
    if all(row[:-1] == firstFeatures for row in dataSet):
        return majorityCnt(classList)
    bestFeature = chooseBestFeature(dataSet)
    if bestFeature < 0:
        # Defensive: never index labels/columns with a "no feature" sentinel.
        return majorityCnt(classList)
    bestLabel = labels[bestFeature]
    # Build the remaining names as a new list instead of deleting from the
    # caller's list in place. Recursive calls do not mutate it either, so one
    # copy can be shared by every branch.
    subLabels = labels[:bestFeature] + labels[bestFeature + 1:]
    dtree = {bestLabel: {}}
    # dict.fromkeys keeps first-seen order, so the printed tree is the same
    # on every run (set order depends on string hash randomization).
    for value in dict.fromkeys(row[bestFeature] for row in dataSet):
        subset = splitData(dataSet, bestFeature, value)
        dtree[bestLabel][value] = createTree(subset, subLabels)
    return dtree
def createDataSet():
    """Return the watermelon data set 2.0 and its feature names.

    Returns:
        A ``(dataSet, labels)`` tuple. ``dataSet`` is a list of 17 samples;
        each sample holds six categorical feature values followed by the class
        label ('好瓜' = good melon, '坏瓜' = bad melon). ``labels`` names the
        six feature columns in order. New lists are built on every call, so
        callers may modify them freely.
    """
    # colour, root, knock sound, texture, navel, touch
    featureNames = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
    samples = [
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
    ]
    return samples, featureNames
if __name__ == '__main__':
    # Build the decision tree for the bundled watermelon data and print it.
    samples, featureNames = createDataSet()
    tree = createTree(samples, featureNames)
    print(tree)