K近邻/决策树实战 — 分类识别评估
文章目录
一、k-近邻简介
什么是近邻法:
K最近邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。该方法的思路是:在特征空间中,如果一个样本附近的k个最近(即特征空间中最邻近)样本的大多数属于某一个类别,则该样本也属于这个类别。
如上图所示,有两类不同的样本数据,分别用蓝色的小正方形和红色的小三角形表示,而图正中间的那个绿色的圆所标示的数据则是待分类的数据。也就是说,现在, 我们不知道中间那个绿色的数据是从属于哪一类(蓝色小正方形or红色小三角形),下面,我们就要解决这个问题:给这个绿色的圆分类。
我们常说,物以类聚,人以群分,判别一个人是一个什么样品质特征的人,常常可以从他/她身边的朋友入手,所谓观其友,而识其人。我们不是要判别图1中那个绿色的圆是属于哪一类数据么,好说,从它的邻居下手。但一次性看多少个邻居呢?从图1中,你还能看到:
- 如果K=3,绿色圆点的最近的3个邻居是2个红色小三角形和1个蓝色小正方形,少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于红色的三角形一类。
- 如果K=5,绿色圆点的最近的5个邻居是2个红色三角形和3个蓝色的正方形,还是少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于蓝色的正方形一类。
于此我们看到,当无法判定当前待分类点是从属于已知分类中的哪一类时,我们可以依据统计学的理论看它所处的位置特征,衡量它周围邻居的权重,而把它归为(或分配)到权重更大的那一类。这就是K近邻算法的核心思想。
KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
KNN 算法本身简单有效,它是一种 lazy-learning 算法,分类器不需要使用训练集进行训练,训练时间复杂度为0。KNN 分类的计算复杂度和训练集中的文档数目成正比,也就是说,如果训练集中文档总数为 n,那么 KNN 的分类时间复杂度为O(n)。
KNN方法虽然从原理上也依赖于极限定理,但在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。
K 近邻算法使用的模型实际上对应于对特征空间的划分。K 值的选择,距离度量和分类决策规则是该算法的三个基本要素:
- K 值的选择会对算法的结果产生重大影响。K值较小意味着只有与输入实例较近的训练实例才会对预测结果起作用,但容易发生过拟合;如果 K 值较大,优点是可以减少学习的估计误差,但缺点是学习的近似误差增大,这时与输入实例较远的训练实例也会对预测起作用,使预测发生错误。在实际应用中,K 值一般选择一个较小的数值,通常采用交叉验证的方法来选择最优的 K 值。随着训练实例数目趋向于无穷和 K=1 时,误差率不会超过贝叶斯误差率的2倍,如果K也趋向于无穷,则误差率趋向于贝叶斯误差率。
- 该算法中的分类决策规则往往是多数表决,即由输入实例的 K 个最临近的训练实例中的多数类决定输入实例的类别
- 距离度量一般采用 Lp 距离,当p=2时,即为欧氏距离,在度量之前,应该将每个属性的值规范化,这样有助于防止具有较大初始值域的属性比具有较小初始值域的属性的权重过大。
KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成反比。 该算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。 该算法只计算“最近的”邻居样本,某一类的样本数量很大,那么或者这类样本并不接近目标样本,或者这类样本很靠近目标样本。无论怎样,数量并不能影响运行结果。可以采用权值的方法(和该样本距离小的邻居权值大)来改进。
该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。目前常用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分。
实现 K 近邻算法时,主要考虑的问题是如何对训练数据进行快速 K 近邻搜索,这在特征空间维数大及训练数据容量大时非常必要。
k值选择通常是一个难以抉择的问题,可以通过经验结合数据集进行选择,也可以通过交叉验证(下文会提到)的方式对一组k值进行测试,选出其中得分最高的k值作为后续训练的k。k值选择一般为单数(避免投票时出现平票的尴尬局面)
一句话概括: 所谓近朱者赤近墨者黑,K-近邻法的核心就是选择与其距离最近的K个“邻居”。
三要素: k值选择、距离度量、分类决策规则
优缺点: 精度高、对异常值不敏感,无数据输入假定。但是计算复杂度高、空间复杂度高。
二、决策树简介
非数值特征的量化
-
名义特征:正交编码
– 例如颜色、形状、性别、职业、字符串中的字符等
-
序数特征:等同于名义特征处理或转化为数值特征
– 例如序号、分级,不能看作是欧氏空间中的数值
-
区间特征:通过设定阈值变成二值特征或序数特征
– 与研究目标之间的关系呈现出明显的非线性。取值是实数,可以比较大小,但是没有一个“自然的”零,比值没有意义
– 例如年龄、温度、考试成绩等
一个简化的树状决策过程例子
决策树(Decision Tree 是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。Entropy = 系统的凌乱程度,使用算法ID3, C4.5和C5.0生成树算法使用熵。这一度量是基于信息学理论中熵的概念。
决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。
分类树(决策树)是一种十分常用的分类方法。它是一种监督学习,所谓监督学习就是给定一堆样本,每个样本都有一组属性和一个类别,这些类别是事先确定的,那么通过学习得到一个分类器,这个分类器能够对新出现的对象给出正确的分类。这样的机器学习就被称之为监督学习。
(以下实践内容具体流程见代码注释)
三、手写k-近邻法,完成电影分类任务
由于作为入门任务,所以这里的数据比较简单,具体的数据见实验结果的代码图.
实验代码
# -*- coding: utf-8 -*-
# @Author : Xenon
# @Date : 2023/2/4 21:18
# @IDE : PyCharm(2022.3.2) Python3.9.13
import os
import operator
import numpy as np
import matplotlib.pyplot as plt
def createDataSet():
"""创建数据集
创建四组二维特征与四组特征的标签,代表了4部电影中动作镜头与喜剧镜头出现次数。
:return: 数据集group、分类标签labels
"""
group = np.array([[1, 101], [5, 89], [108, 5], [115, 8]])
# print(group)
labels = ['喜剧片', '喜剧片', '动作片', '动作片']
return group, labels
def classify0(inX, dataSet, labels, k):
"""kNN算法的实现
:param inX: 新样本
:param dataSet: 已知样本
:param labels: 已知样本标签
:param k: 近邻样本数量
:return: 分类结果sortedClassCount[0][0]
"""
dataSetSize = dataSet.shape[0] # 读取数据集第一维长度
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet # 对新样本构造数组
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
# 计算新样本与所有已知样本的欧氏距离
distances = sqDistances ** 0.5
sortedDistIndices = distances.argsort()
classCount = {
}
# 统计最近的前k的已知样本
for i in range(k):
voteIlabel = labels[sortedDistIndices[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
# 统计前k个已知样本中最多的类别,作为新样本的类别
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def visualization():
"""
将坐标进行可视化展示
:return:
"""
# matplotlib画图中中文显示会有问题,需要这两行设置默认字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 0 数据准备
X = np.array([1, 5])
Y = np.array([101, 89])
X2 = np.array([108, 115])
Y2 = np.array([5, 8])
xmx = X.max() + 10
ymx = Y.max() + 10
# 1 设置x,y坐标轴的刻度显示范围
fig = plt.figure()
plt.xlim(xmin=0, xmax=120) # x轴的范围
plt.ylim(ymin=0, ymax=120) # y轴的范围
plt.xlabel('X')
plt.ylabel('Y')
plt.title('电影分类可视化') # 图的标题
# 2 画图显示
plt.scatter(X, Y, marker='o', alpha=1, color="blue", label='喜剧片')
plt.scatter(X2, Y2, marker='x', alpha=1, color="green", label='动作片')
plt.scatter(101, 20, marker='*', alpha=1, color="red", label='测试样本')
plt.legend() # label='类别A' 图中显示
# 3 保存图像
path = os.getcwd() # 获取当前的工作路径
fileName = "001"
filePath = path + "\\" + fileName + ".png"
plt.savefig(filePath, dpi=600) # dpi越大,图像越清晰,当然图像所占的存储也大
plt.show() # 显示图像在保存后调用,否则会使保存的图片为空白
if __name__ == '__main__':
group, labels = createDataSet() # 调用createDataSet()创建数据集
test = [101, 20] # 测试样本
test_class = classify0(test, group, labels, 3) # 执行kNN分类
print('本电影分类为: ' + test_class)
visualization() # 可视化展示
实验结果
结果可视化:
由图片可直观看到测试样本更靠近动作片,所以分类也为动作片,分类结果正确。
四、手写k-近邻法,完成业务员评估任务
数据集说明
原始数据集示例如下:
文件每一行为一个样本,共有4列,以制表符分割,前3列为属性(飞行里程数、联系客户次数、促成交易数),第4列为标签(公司对业务员的评价),总共有1000个样本,其中900个用于训练,100个用于测试。
数据集可视化:
数据集获取:见csdn资源下载(免费)
实验代码
# -*- coding: utf-8 -*-
# @Author : Xenon
# @Date : 2023/2/4 22:35
# @IDE : PyCharm(2022.3.2) Python3.9.13
import os
import operator
import numpy as np
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D
def classify0(inX, dataSet, labels, k):
"""手写kNN算法的实现
:param inX: 新样本
:param dataSet: 已知样本
:param labels: 已知样本标签
:param k: 近邻样本数量
:return: 分类结果sortedClassCount[0][0]
"""
dataSetSize = dataSet.shape[0] # 读取数据集第一维长度
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet # 对新样本构造数组
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
# 计算新样本与所有已知样本的欧氏距离
distances = sqDistances ** 0.5
sortedDistIndices = distances.argsort()
classCount = {
}
# 统计最近的前k的已知样本
for i in range(k):
voteIlabel = labels[sortedDistIndices[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
# 统计前k个已知样本中最多的类别,作为新样本的类别
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def file2matrix(filename):
"""读取数据
读取指定文件的内容,文件每一行为一个样本,共有4列,以制表符分割,前3列为属性(飞行里程数、联系客户次数、促成交易数),
第4列为标签(公司对业务员的评价)。将属性与标签分别存储,并将标签转化为整形。
:param filename: 文件名
:return: 属性集returnMat, 标签集classLabelVector
"""
fr = open(filename)
arrayOLines = fr.readlines()
numberOfLines = len(arrayOLines)
returnMat = np.zeros((numberOfLines, 3))
classLabelVector = []
index = 0
for line in arrayOLines:
line = line.strip()
listFromLine = line.split('\t')
returnMat[index, :] = listFromLine[0:3]
if listFromLine[-1] == 'didntLike':
classLabelVector.append(1)
elif listFromLine[-1] == 'smallDoses':
classLabelVector.append(2)
elif listFromLine[-1] == 'largeDoses':
classLabelVector.append(3)
index += 1
return returnMat, classLabelVector
def autoNorm(dataSet):
"""进行Min-Max归一化
获取所有样本各个属性的最小值minVals,最大值maxVals,利用Min-Max归一化公式进行归一化。
在样本各属性尺度不同时,kNN法必须进行归一化处理,
任务1中不是必须进行归一化处理,是因为事先已知样本的各属性是同一尺度的(都是电影中某场景出现次数)
:param dataSet: 原始属性集
:return: 归一化数据结果normDataSet(1000个), 数据范围ranges, 最小值minVals
"""
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals # 得到数据集范围
normDataSet = np.zeros(np.shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - np.tile(minVals, (m, 1))
normDataSet = normDataSet / np.tile(ranges, (m, 1))
# print(normDataSet)
return normDataSet, ranges, minVals
def salemanClassTest():
"""打开salemanTestSet.txt文件,取前10%数据为测试集,后90%数据为训练集,输出测试结果与错误率"""
filename = "salemanTestSet.txt"
salemanDataMat, salemanLabels = file2matrix(filename) # 得到训练样本和标签
hoRatio = 0.10 # 测试集占比
normMat, ranges, minVals = autoNorm(salemanDataMat)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0 # 初始错误率
for i in range(numTestVecs):
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],
salemanLabels[numTestVecs:m], 4)
print("分类结果:%d\t真实类别:%d" % (classifierResult, salemanLabels[i]))
if classifierResult != salemanLabels[i]: # 分类错误
errorCount += 1.0
print("错误率:%f%%" % (errorCount / float(numTestVecs) * 100))
def visualization():
# matplotlib画图中中文显示会有问题,需要这两行设置默认字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 绘图设置
fig = plt.figure()
# ax = Axes3D(fig)
ax = fig.add_subplot(111, projection='3d')
# 数据获取
filename = "salemanTestSet.txt"
salemanDataMat, salemanLabels = file2matrix(filename)
normMat, ranges, minVals = autoNorm(salemanDataMat)
# 数据分类
for i, da in enumerate(normMat):
X = da[0]
Y = da[1]
Z = da[2]
if salemanLabels[i] == 1:
ax.scatter(X, Y, Z, c='r', marker='o')
elif salemanLabels[i] == 2:
ax.scatter(X, Y, Z, c='b', marker='o')
elif salemanLabels[i] == 3:
ax.scatter(X, Y, Z, c='g', marker='o')
else:
raise "Error..."
plt.legend(["didntLike", "smallDoses", "largeDoses"])
ax.set_xlabel('飞行里程数')
ax.set_ylabel('联系客户次数')
ax.set_zlabel('促成交易数')
plt.title('样本可视化') # 图的标题
path = os.getcwd() # 获取当前的工作路径
fileName = "002"
filePath = path + "\\" + fileName + ".png"
plt.savefig(filePath, dpi=600) # dpi越大,图像越清晰,当然图像所占的存储也大
plt.show()
if __name__ == '__main__':
salemanClassTest()
#visualization() 数据集可视化
实验结果
使用K-近邻法对这100个样本的分类部分结果如上,总分类错误率为4%,也就是仅仅分错了4张,仅看分类错误率来说效果还是比较理想的。
五、sk-learn实现k-近邻法,完成手写字体识别任务
对于此数据集在前面的博文 前馈神经网络与支持向量机实战 中做过详细的说明,这里不做过多的解释。
实验代码
# -*- coding: utf-8 -*-
# @Author : Xenon
# @Date : 2023/2/4 22:41
# @IDE : PyCharm(2022.3.2) Python3.9.13
from os import listdir
import numpy as np
from sklearn.neighbors import KNeighborsClassifier as kNN
def img2vector(filename):
"""将图片展开
将32*32的二维数组转化为长度1024的一维数组
:param filename: 图片文件
:return: 转换后的图片文件returnVect
"""
returnVect = np.zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32 * i + j] = int(lineStr[j])
return returnVect
def handwritingClassTest():
hwLabels = []
# 读取trainingDigits文件夹下的文件
# 文件是32*32的矩阵,代表手写数字,数字笔画用1表示,空白处用0表示,
trainingFileList = listdir('trainingDigits')
m = len(trainingFileList)
trainingMat = np.zeros((m, 1024))
for i in range(m):
fileNameStr = trainingFileList[i]
# 文件名的第一个数字(下划线前的数字)是该手写数字的内容,也是样本标签
classNumber = int(fileNameStr.split('_')[0])
hwLabels.append(classNumber)
trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
# 使用sk-learn的kNN对象构建kNN分类器
neigh = kNN(n_neighbors=3, algorithm='auto')
# 使用fit方法拟合
neigh.fit(trainingMat, hwLabels)
# 读取testDigits文件夹下的文件
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
# 遍历每一个文件,转换其属性(像素点矩阵),获取其标签,使用predict方法进行预测,并统计错误率
for i in range(mTest):
fileNameStr = testFileList[i]
classNumber = int(fileNameStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))
classifierResult = neigh.predict(vectorUnderTest)
print("分类返回结果为%d\t真实结果为%d" % (classifierResult, classNumber))
if classifierResult != classNumber:
errorCount += 1.0
print("总共错了%d个数据\n错误率为%f%%" % (errorCount, errorCount / mTest * 100))
if __name__ == '__main__':
handwritingClassTest()
实验结果
总体分类错误率错误率为1.268499%,比较理想。
六、sk-learn实现决策树,完成隐形眼镜评估任务
实验代码
# -*- coding: utf-8 -*-
# @Author : Xenon
# @Date : 2023/2/4 22:46
# @IDE : PyCharm(2022.3.2) Python3.9.13
import pandas as pd
from sklearn import tree
from sklearn.preprocessing import LabelEncoder
def main():
# 读取lenses.txt文件,该文件的每一行为一个样本
# 前4列为属性(分别为'age', 'prescript', 'astigmatic', 'tearRate'),最后1列为标签。
with open('lenses.txt', 'r') as fr:
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lenses_target = []
for each in lenses:
lenses_target.append(each[-1])
print(lenses_target)
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lenses_list = []
lenses_dict = {
}
for each_label in lensesLabels:
for each in lenses:
lenses_list.append(each[lensesLabels.index(each_label)])
lenses_dict[each_label] = lenses_list
lenses_list = []
# 将数据保存在pandas的DataFrame中
lenses_pd = pd.DataFrame(lenses_dict)
le = LabelEncoder()
for col in lenses_pd.columns:
lenses_pd[col] = le.fit_transform(lenses_pd[col])
# 将sk-learn的DecisionTreeClassifier类进行实例化,生成分类器,参数为max_depth = 4
clf = tree.DecisionTreeClassifier(max_depth=4)
# 使用fit方法进行训练,输入属性集与标签集
clf.fit(lenses_pd.values.tolist(), lenses_target)
# 使用predict方法进行测试,测试集与训练集相同,观察决策树是否得出了与真实标签相同的结论。
print(clf.predict(lenses_pd.values.tolist()))
if __name__ == '__main__':
main()
实验结果
七、手写决策树,完成银行客户信誉评估任务
实验代码
# -*- coding: utf-8 -*-
# @Author : Xenon
# @Date : 2023/2/4 22:51
# @IDE : PyCharm(2022.3.2) Python3.9.13
"""
提示:本任务可能涉及香农熵的计算、决策树最优属性的选择,其中可能会使用递归方法,如果遇到困难,请回顾一下Python基础知识,
如递归的使用方法,以及python中函数内部修改变量和列表时,其它位置的变量和列表是否同步的规律
"""
import math
import operator
def calcShannonEnt(dataSet):
numEntires = len(dataSet)
labelCounts = {
}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntires
shannonEnt -= prob * math.log(prob, 2)
return shannonEnt
def createDataSet():
dataSet = [[0, 0, 0, 0, 'no'],
[0, 0, 0, 1, 'no'],
[0, 1, 0, 1, 'yes'],
[0, 1, 1, 0, 'yes'],
[0, 0, 0, 0, 'no'],
[1, 0, 0, 0, 'no'],
[1, 0, 0, 1, 'no'],
[1, 1, 1, 1, 'yes'],
[1, 0, 1, 2, 'yes'],
[1, 0, 1, 2, 'yes'],
[2, 0, 1, 2, 'yes'],
[2, 0, 1, 1, 'yes'],
[2, 1, 0, 1, 'yes'],
[2, 1, 0, 2, 'yes'],
[2, 0, 0, 0, 'no']]
labels = ['年龄', '有工作', '有自己的房子', '信贷情况']
return dataSet, labels
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
def majorityCnt(classList):
classCount = {
}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels, featLabels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1 or len(labels) == 0:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
featLabels.append(bestFeatLabel)
myTree = {
bestFeatLabel: {
}}
del (labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
return myTree
def classify(inputTree, featLabels, testVec):
firstStr = next(iter(inputTree))
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
if __name__ == '__main__':
dataSet, labels = createDataSet()
featLabels = []
myTree = createTree(dataSet, labels, featLabels)
testVec = [0, 1]
result = classify(myTree, featLabels, testVec)
if result == 'yes':
print('测试样本%s为: 高信用' % testVec)
if result == 'no':
print('测试样本%s为: 低信用' % testVec)