《机器学习实战》讲的CART算法是回归树和模型树。没有给出具体的分类树算法。
于是参考ID3的算法,基于CART的思想实现分类树的算法(不包含剪枝)。
原理网上太多了,直接贴代码和注释。
import numpy as np
class CARTClassifier:
def caclGini(self, dataSet):
'''
计算基尼指数
:param dataSet: 包含标签类的数据集 m*(n+1)。 离散特征用int数据类型,连续特征用float数据类型。 如果包含离散特征,则必须是二维列表类型。(因为array的数据类型都是一样的)
:return: 数据集的基尼指数
'''
targets = [example[-1] for example in dataSet] # 获取标签列表
d = {} # 存储标签及其相应的数量
gini = 1
for target in targets:
if target not in d:
d[target] = 0
d[target] += 1
for target in d.keys():
prob = d[target] / len(dataSet)
gini -= prob**2 # 计算基尼指数
return gini
def splitDataSet(self, dataSet, feature, value):
'''
按照给定的特征和特征值切分数据集。这里根据是否连续型数据或离散型数据,采用不同的二分法
:param dataSet: 同上
:param feature: 待切分的特征
:param value: 选取的特征值
:return: 二分切割后的2个子数据集
'''
left, right = [], []
if isinstance(value, int): # 离散型
for i in range(len(dataSet)):
reduceSet = dataSet[i][: feature]
reduceSet.extend(dataSet[i][feature + 1:])
if dataSet[i][feature] == value: # 离散型数据根据值是否等于给定值划分数据集
left.append(reduceSet)
if dataSet[i][feature] != value:
right.append(reduceSet)
if isinstance(value, float): # 连续型
for i in range(len(dataSet)):
reduceSet = dataSet[i][: feature]
reduceSet.extend(dataSet[i][feature + 1:])
if dataSet[i][feature] <= value: # 连续型数据根据值是否小于等于给定值划分数据集
left.append(reduceSet)
if dataSet[i][feature] > value:
right.append(reduceSet)
return left, right
def chooseBestFeature(self, dataSet):
'''
选出使基尼指数最小的特征 与 对应的特征值
:param dataSet: 同上
:return: 最佳分类特征与特征值
'''
n = len(dataSet[0]) - 1 # 获取特征的数量
splitGini, bestFeatrue, bestValue = None, None, None # 声明变量,防止“本地变量调用前被引用”的错误
minGini = np.inf # 一个无限大的正数
for featureIndex in range(n):
values = set([example[featureIndex] for example in dataSet]) # 获取所有可能的特征取值
for value in values:
left, right = self.splitDataSet(dataSet, featureIndex, value)
leftGini = self.caclGini(left) # 左子树
rightGini = self.caclGini(right) # 右子树
splitGini = (len(left) / len(dataSet)) * leftGini + (len(right) / len(dataSet)) * rightGini # 分割后的基尼指数
# print('特征索引:', featureIndex, '特征值:', value, '基尼指数:', splitGini)
if splitGini < minGini:
minGini = splitGini
bestFeatrue = featureIndex
bestValue = value
return bestFeatrue, bestValue
def classtarget(self, targets):
'''
辅助函数
:param targets: 类别列表
:return: 返回最终叶子中,数量最多的类别
'''
d = {}
for target in targets:
if target not in d:
d[target] = 0
d[target] += 1
return max(d, d.get) # 获取值最大的键
def createTree(self, dataSet, featureLabel):
'''
创建决策树
:param dataSet: 同上
:param featureLabel: 特征索引-特征含义 对照表
:return: 二叉决策树
'''
targets = [example[-1] for example in dataSet]
if len(set(targets)) == 1: # 只包含一个类别
return targets[0]
if len(dataSet[0]) == 1: # 没有可分的特征
return self.classtarget(targets)
bestFeatureIndex, bestValue = self.chooseBestFeature(dataSet)
featureLable_copy = featureLabel.copy() # 避免对源数据的修改
bestFeatureLabel = featureLable_copy[bestFeatureIndex]
if isinstance(bestValue, int): # 如果是离散型数据,删除对照表中的对应的索引,因为不会再用到。 如果是连续型的数据,则保留,因为可能再用到
del featureLable_copy[bestFeatureIndex]
mytree = {}
mytree['FeatLabel'] = bestFeatureLabel
mytree['FeatValue'] = bestValue
lSet, rSet = self.splitDataSet(dataSet, bestFeatureIndex, bestValue)
mytree['left'] = self.createTree(lSet, featureLable_copy)
mytree['right'] = self.createTree(rSet, featureLable_copy)
return mytree
def predict(self, tree, featureLabel, testvec):
'''
根据训练好的决策树,输出待测试样本的类别
:param tree: 训练好的决策树
:param featureLabel: 同上
:param testvec: 待测试的样本
:return: 预测类别
'''
if not isinstance(tree, dict):
return tree
bestFeatureIndex = featureLabel.index(tree['FeatLabel'])
value = testvec[bestFeatureIndex]
if isinstance(value, int):
if value == tree['FeatValue']:
return self.predict(tree['left'], featureLabel, testvec)
if value != tree['FeatValue']:
return self.predict(tree['right'], featureLabel, testvec)
if isinstance(value, float):
if value <= tree['FeatValue']:
return self.predict(tree['left'], featureLabel, testvec)
if value > tree['FeatValue']:
return self.predict(tree['right'], featureLabel, testvec)
# 以下是测试
dataSet = [[0, 1, 0.15, 'yes'],
[1, 1, 0.25, 'yes'],
[1, 0, 0.21, 'no'],
[0, 1, 0.45, 'no'],
[0, 1, 0.55, 'no']]
feature_label = ['No Surfacing', 'Flippers', 'other']
cart = CARTClassifier()
mytree = cart.createTree(dataSet, feature_label)
print(mytree)
print(cart.predict(mytree, feature_label, [1, 0, 0.35]))