版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_27668313/article/details/79063196
先上代码,理论部分暂时没空写(主要是编辑公式太麻烦)。代码适用于 Python 3.x,实现的是分类树而不是回归树,不包含剪枝部分。
# NOTE: the star import is kept from the original script. It shadows builtins
# such as ``sum``/``any``/``all``, so the code below deliberately avoids them.
from numpy import *
import operator
import numpy as np


def loadDataSet():
    """Return the toy animal data set and the names of its feature columns.

    Returns:
        A ``(dataSet, features)`` tuple. ``dataSet`` is a list of 13 samples;
        each sample holds 8 categorical feature values followed by the class
        label. ``features`` lists the 8 feature names in column order.
    """
    dataSet = [
        ['constant', 'hair', 'true', 'false', 'false', 'false', 'true', 'false', 'mammal'],
        ['cold_blood', 'scale', 'false', 'true', 'false', 'false', 'false', 'true', 'reptile'],
        ['cold_blood', 'scale', 'false', 'true', 'false', 'true', 'false', 'false', 'fish'],
        ['constant', 'hair', 'true', 'false', 'false', 'true', 'false', 'false', 'mammal'],
        ['cold_blood', 'none', 'false', 'true', 'false', 'sometime', 'true', 'true', 'amphibious'],
        ['cold_blood', 'scale', 'false', 'true', 'false', 'false', 'true', 'false', 'reptile'],
        ['constant', 'hair', 'true', 'false', 'true', 'false', 'true', 'false', 'mammal'],
        ['constant', 'skin', 'true', 'false', 'false', 'false', 'true', 'false', 'mammal'],
        ['cold_blood', 'scale', 'true', 'false', 'false', 'true', 'false', 'false', 'fish'],
        ['cold_blood', 'scale', 'false', 'true', 'false', 'sometime', 'true', 'false', 'reptile'],
        ['constant', 'bristle', 'true', 'false', 'false', 'false', 'true', 'true', 'mammal'],
        ['cold_blood', 'scale', 'false', 'true', 'false', 'true', 'false', 'false', 'fish'],
        ['cold_blood', 'none', 'false', 'true', 'false', 'sometime', 'true', 'true', 'amphibious'],
    ]
    features = ["temperature", "cover", "viviparity", "egg", "fly", "water", "leg", "hibernate"]
    return dataSet, features


def calGini(dataSet):
    """Return the Gini impurity of the class labels (last column) of *dataSet*.

    An empty data set yields 1, exactly as the original implementation did;
    callers weight that value by a probability of 0, so it never matters.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    gini = 1.0
    # Subtract the squared class probabilities one by one (same order as the
    # labels were first seen) so floating-point results stay reproducible.
    for count in labelCounts.values():
        prop = float(count) / numEntries
        gini -= prop * prop
    return gini


def splitDataSet(dataSet, axis, values):
    """Return the samples whose feature *axis* equals any entry of *values*.

    The matching feature column is removed from every returned sample. Each
    sample is returned at most once, even when *values* contains duplicates.
    """
    wanted = list(values)
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] in wanted:
            # Drop column ``axis`` from the sample.
            retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    return retDataSet


def splitData(data, i, setValue):
    """Split *data* in two on whether feature *i* equals *setValue*.

    Returns:
        ``(subData_left, subData_right)``. The left part holds the matching
        samples with column *i* removed; the right part holds the remaining
        samples unchanged (they keep column *i*).
    """
    subData_left = []
    subData_right = []
    for sample in data:
        if sample[i] == setValue:
            subData_left.append(sample[:i] + sample[i + 1:])
        else:
            subData_right.append(sample)
    return subData_left, subData_right


def chooseBestFeatureToSplit(data):
    """Find the binary split (feature index, value) with the lowest Gini index.

    Every distinct value of every feature is tried as a "value vs. everything
    else" split. Splits that leave the "else" side empty are skipped because
    they do not partition the data. Candidate values are visited in order of
    first appearance so the result is deterministic; on ties the candidate
    visited last wins.

    Returns:
        ``(feat, featval)``, or ``(-1, None)`` when *data* is empty, has no
        feature columns, or no feature takes more than one value.
    """
    feat = -1
    featval = None
    if not data:
        return feat, featval
    bestGini = 1.0
    total = len(data)
    for i in range(len(data[0]) - 1):  # last column is the class label
        # dict.fromkeys de-duplicates while preserving first-seen order;
        # a plain set would make tie-breaking depend on string hashing.
        for value in dict.fromkeys(sample[i] for sample in data):
            subData_left, subData_right = splitData(data, i, value)
            if not subData_right:
                continue  # feature i is constant here: not a real split
            prob_L = float(len(subData_left) / total)
            prob_R = float(len(subData_right) / total)
            # Weighted Gini index of splitting feature i on ``value``.
            gini = prob_L * calGini(subData_left) + prob_R * calGini(subData_right)
            if gini <= bestGini:
                bestGini = gini
                feat = i
                featval = value
    return feat, featval


def splitFeatValue(data, featlabel, featval):
    """Separate the chosen split value from the other values of a feature.

    Args:
        data: the samples.
        featlabel: index of the split feature.
        featval: the chosen split value of that feature.

    Returns:
        ``(bestval, notbestval)``, both lists. ``bestval`` is ``[featval]``;
        ``notbestval`` holds the feature value of every sample whose value
        differs from *featval* (one entry per sample, so duplicates are kept).
    """
    bestval = [featval]
    notbestval = [sample[featlabel] for sample in data if sample[featlabel] != featval]
    return bestval, notbestval


def majorityCnt(classList):
    """Return the most frequent label in *classList*.

    Ties are resolved in favour of the label that appears first.
    *classList* must not be empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build a binary classification tree using the Gini index.

    Args:
        dataSet: non-empty list of samples; the last column is the class label.
        labels: feature names, one per feature column of *dataSet*. The list
            is not modified.

    Returns:
        A class label for a leaf, otherwise a nested dict of the form
        ``{feature_name: {split_value: subtree, 'else': subtree}}``. The
        ``split_value`` branch covers samples whose feature equals that value;
        ``'else'`` covers all other samples. A feature value that is literally
        ``'else'`` would collide with that key.

    Raises:
        ValueError: if *dataSet* is empty.
    """
    if not dataSet:
        raise ValueError("createTree requires a non-empty dataSet")
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # all samples share one class: pure leaf
    if len(dataSet[0]) == 1:
        # Only the label column is left: no feature to split on.
        return majorityCnt(classList)
    bestFeat, bestBinarySplit = chooseBestFeatureToSplit(dataSet)
    if bestFeat == -1:
        # No feature separates the samples: fall back to a majority vote.
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    data_left, data_right = splitData(dataSet, bestFeat, bestBinarySplit)
    # The left subset lost column ``bestFeat``, so its label list must too.
    subLabels1 = labels[:bestFeat] + labels[bestFeat + 1:]
    # The right subset keeps every column; copy so recursion cannot alias.
    subLabels2 = labels[:]
    myTree = {bestFeatLabel: {}}
    myTree[bestFeatLabel][bestBinarySplit] = createTree(data_left, subLabels1)
    # 'else' stands for every other value of the split feature; listing them
    # all as the key would be unwieldy.
    myTree[bestFeatLabel]['else'] = createTree(data_right, subLabels2)
    return myTree


if __name__ == '__main__':
    data, feature = loadDataSet()
    Tree = createTree(data, feature)
    print(Tree)