K-Nearest Neighbors and Decision Trees in Practice: Classification and Evaluation Tasks

1. Introduction to k-Nearest Neighbors

What is the nearest-neighbor method?


The K-Nearest Neighbor (KNN) classification algorithm is a theoretically mature method and one of the simplest machine learning algorithms. Its idea is: in the feature space, if most of the k samples nearest to a query sample (that is, its nearest neighbors in the feature space) belong to a certain category, then the query sample is assigned to that category as well.
(Figure: two known classes, drawn as small blue squares and small red triangles, with a green circle in the middle marking the sample to be classified.)
As shown in the figure above, there are two different types of labeled samples, represented by small blue squares and small red triangles, and the point marked by the green circle in the middle is the sample to be classified. In other words, we do not yet know which category the green point belongs to (blue square or red triangle), and the task is to classify this green circle.

We often say that birds of a feather flock together and that people sort themselves into groups. To judge someone's character, you can often start with the friends around them. Likewise, to decide which class the green circle in the figure belongs to, we start with its neighbors. But how many neighbors should we look at? From the figure we can see:

  • If K=3, the 3 nearest neighbors of the green point are 2 small red triangles and 1 small blue square. By majority vote, the green point is assigned to the red-triangle class.
  • If K=5, the 5 nearest neighbors of the green point are 2 red triangles and 3 blue squares. Again by majority vote, the green point is assigned to the blue-square class.

Here we see that when the class of a new point cannot be determined directly from the known classification, we can look at where it sits, weigh the classes of its neighbors, and assign it to the class that carries the greater weight. This is the core idea of the K-nearest neighbor algorithm.

In the KNN algorithm, the selected neighbors are objects that have already been correctly classified. The classification decision determines the class of the sample to be classified solely from the classes of its nearest one or several samples.

The KNN algorithm itself is simple and effective. It is a lazy-learning algorithm: the classifier needs no training on the training set, so the training cost is essentially zero. The computational cost of classifying one sample is proportional to the number of documents in the training set; if the training set contains n documents, the time complexity of a KNN classification is O(n).

Although the KNN method rests in principle on asymptotic (limit) results, the actual class decision involves only a small number of neighboring samples. Because KNN relies mainly on the limited surrounding samples rather than on discriminating class regions, it is better suited than other methods for sample sets whose class regions intersect or overlap heavily.

The model used by the K-nearest neighbor algorithm corresponds to a partition of the feature space. The choice of the value of K, the distance measure and the classification decision rule are the three basic elements of the algorithm:

  1. The choice of the value of K can have a significant impact on the results. A small K means that only training examples very close to the input instance influence the prediction, which makes overfitting more likely; a large K reduces the estimation error of learning, but increases the approximation error, because training examples far from the input instance also influence the prediction and can make it wrong. In practice K is usually chosen fairly small, and cross-validation is commonly used to select the optimal value. As the number of training examples tends to infinity with K=1, the error rate is at most twice the Bayes error rate; if K also tends to infinity (while growing more slowly than the number of samples), the error rate tends to the Bayes error rate.
  2. The classification decision rule in this algorithm is usually majority voting: the class of the input instance is determined by the majority class among its K nearest training instances.
  3. The distance measure is generally the Lp distance; with p=2 it is the Euclidean distance. Before measuring distances, each attribute should be normalized, which helps prevent attributes with larger value ranges from outweighing attributes with smaller value ranges (see the sketch after this list).
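
A minimal sketch of these two ingredients, Min-Max normalization and the Euclidean (p=2) distance, using NumPy; the arrays and function names here are illustrative only and are not part of the tasks below:

```python
import numpy as np

def min_max_normalize(X):
    """Scale every column of X to [0, 1] so no attribute dominates the distance."""
    mins, maxs = X.min(axis=0), X.max(axis=0)
    return (X - mins) / (maxs - mins)

def euclidean_distances(x, X):
    """L2 (p=2) distance between a query point x and every row of X."""
    return np.sqrt(((X - x) ** 2).sum(axis=1))

X = np.array([[1.0, 101.0], [5.0, 89.0], [108.0, 5.0], [115.0, 8.0]])
Xn = min_max_normalize(X)
print(euclidean_distances(Xn[0], Xn))  # distances from the first sample to all samples
```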

The KNN algorithm can be used for regression as well as classification: find the k nearest neighbors of a sample and assign the sample the average of their attribute values. A more useful variant gives neighbors at different distances different weights, for example weights inversely proportional to distance. The main drawback for classification is class imbalance: when one class has many samples and the others have few, the neighborhood of a new sample may be dominated by the bulk class simply because of its size. Since the algorithm looks only at the "nearest" neighbor samples, a large class contributes many candidate neighbors whether or not those samples are truly representative of the target sample, so raw quantity can sway the result. This can be mitigated by weighting, giving neighbors that are closer to the sample a larger weight.

Another disadvantage of the method is its computational cost: for every sample to be classified, the distance to all known samples must be computed in order to find its K nearest neighbors. A common remedy is to edit the known samples in advance and remove those that contribute little to classification. The algorithm is better suited to automatically classifying class regions with relatively many samples, while class regions with few samples are more prone to misclassification.

When implementing the K-nearest neighbor algorithm, the main concern is how to search the training data for the K nearest neighbors quickly, which becomes essential when the feature space is high-dimensional and the training set is large.

Choosing the value of k is usually a difficult decision. It can be chosen from experience with the data set, or a set of candidate k values can be compared by cross-validation, keeping the best-scoring k for subsequent training; a small sketch follows below. k is usually chosen to be odd, to avoid the awkward situation of a tied vote.
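
A hedged sketch of selecting k by 5-fold cross-validation with scikit-learn; the iris data set here is used purely for illustration and is not one of the data sets in the tasks below:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
scores = {}
for k in range(1, 16, 2):  # odd values of k to avoid tied votes
    knn = KNeighborsClassifier(n_neighbors=k)
    scores[k] = cross_val_score(knn, X, y, cv=5).mean()  # mean accuracy over 5 folds
best_k = max(scores, key=scores.get)
print("best k:", best_k, "cv accuracy:", scores[best_k])
```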

One-sentence summary: he who stays near vermilion turns red, and he who stays near ink turns black; the core of the K-nearest neighbor method is to pick the K "neighbors" closest to the sample.

Three elements: k value selection, distance measure, classification decision rule

Pros and cons: high accuracy, insensitive to outliers, no assumptions about the input data; however, both the computational complexity and the space complexity are high.

2. Introduction to Decision Trees

Quantization of non-numerical features

  • Nominal features: orthogonal (one-hot) encoding; see the sketch after this list

    – e.g., color, shape, gender, occupation, characters in strings

  • Ordinal features: treat them like nominal features, or convert them to numerical features

    – e.g., serial numbers and grades; such codes cannot be treated as values in a Euclidean space

  • Interval features: convert them to binary or ordinal features by setting thresholds

    – their relationship to the target is clearly non-linear; the values are real numbers that can be compared, but there is no "natural" zero, so ratios are meaningless

    – e.g., age, temperature, test scores
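
A minimal sketch of one-hot (orthogonal) encoding for a nominal feature, assuming pandas is available; the column name "color" is purely illustrative:

```python
import pandas as pd

df = pd.DataFrame({"color": ["red", "green", "blue", "green"]})
# Each distinct value becomes its own 0/1 column, so no artificial ordering is introduced.
print(pd.get_dummies(df, columns=["color"]))
```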

A simplified example of a tree decision process
(Figure: a simplified decision-tree flow.)

A decision tree is a decision-analysis method that, given the probabilities of various situations, builds a tree of decisions to estimate the probability that the expected net present value is greater than or equal to zero, and thereby evaluates project risk and feasibility; it is a graphical method of probability analysis. Because the decision branches are drawn like the branches of a tree, it is called a decision tree. In machine learning, a decision tree is a predictive model that represents a mapping between object attributes and object values. Entropy measures the disorder of the system; the tree-growing algorithms ID3, C4.5 and C5.0 all use entropy, a measure based on the concept of entropy from information theory.
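
A minimal sketch of the base-2 Shannon entropy mentioned above, computed for a list of class labels (the hand-written version used later in task 7 is calcShannonEnt):

```python
from collections import Counter
from math import log2

def entropy(labels):
    """H = -sum(p_i * log2(p_i)) over the class frequencies p_i."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

print(entropy(["yes", "yes", "no"]))  # about 0.918 bits
```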

A decision tree is a tree structure in which each internal node represents a test on an attribute, each branch represents a test output, and each leaf node represents a category.

A classification tree (decision tree) is a very commonly used classification method. It is a form of supervised learning: given a set of samples, each with a set of attributes and a predetermined category, a classifier is learned that can assign the correct category to new objects.



(See the code comments for the detailed steps of each task below.)

3. Handwritten k-nearest neighbor method to complete the movie classification task

Because this is an introductory task, the data are kept simple; the actual values can be seen in the code and in the experimental-results figure.

Experimental code

# -*- coding: utf-8 -*-
# @Author : Xenon
# @Date : 2023/2/4 21:18 
# @IDE : PyCharm(2022.3.2) Python3.9.13
import os
import operator
import numpy as np
import matplotlib.pyplot as plt


def createDataSet():
    """创建数据集
    创建四组二维特征与四组特征的标签,代表了4部电影中动作镜头与喜剧镜头出现次数。
    :return: 数据集group、分类标签labels
    """
    group = np.array([[1, 101], [5, 89], [108, 5], [115, 8]])
    # print(group)
    labels = ['喜剧片', '喜剧片', '动作片', '动作片']  # comedy, comedy, action, action
    return group, labels


def classify0(inX, dataSet, labels, k):
    """kNN算法的实现
    :param inX: 新样本
    :param dataSet: 已知样本
    :param labels: 已知样本标签
    :param k: 近邻样本数量
    :return: 分类结果sortedClassCount[0][0]
    """
    dataSetSize = dataSet.shape[0]  # number of known samples (size of the first dimension)
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet  # tile the new sample and subtract every known sample
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    # Euclidean distance between the new sample and every known sample
    distances = sqDistances ** 0.5
    sortedDistIndices = distances.argsort()
    classCount = {}
    # tally the classes of the k nearest known samples
    for i in range(k):
        voteIlabel = labels[sortedDistIndices[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # the most frequent class among the k nearest samples becomes the class of the new sample
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def visualization():
    """
    将坐标进行可视化展示
    :return:
    """
    # matplotlib画图中中文显示会有问题,需要这两行设置默认字体
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    # 0. prepare the data
    X = np.array([1, 5])
    Y = np.array([101, 89])
    X2 = np.array([108, 115])
    Y2 = np.array([5, 8])

    xmx = X.max() + 10
    ymx = Y.max() + 10

    # 1. set the display range of the x and y axes
    fig = plt.figure()
    plt.xlim(xmin=0, xmax=120)  # x-axis range
    plt.ylim(ymin=0, ymax=120)  # y-axis range
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.title('电影分类可视化')  # figure title ("movie classification visualization")

    # 2. draw the scatter plots
    plt.scatter(X, Y, marker='o', alpha=1, color="blue", label='喜剧片')    # comedy
    plt.scatter(X2, Y2, marker='x', alpha=1, color="green", label='动作片')  # action
    plt.scatter(101, 20, marker='*', alpha=1, color="red", label='测试样本')  # test sample
    plt.legend()  # show the legend

    # 3. save the figure
    path = os.getcwd()  # current working directory
    fileName = "001"
    filePath = path + "\\" + fileName + ".png"
    plt.savefig(filePath, dpi=600)  # a higher dpi gives a sharper (and larger) image

    plt.show()  # show() is called after savefig(), otherwise the saved image would be blank


if __name__ == '__main__':
    group, labels = createDataSet()  # build the data set
    test = [101, 20]  # test sample
    test_class = classify0(test, group, labels, 3)  # run kNN classification with k=3
    print('本电影分类为: ' + test_class)  # print the predicted class
    visualization()  # visualize the result

Experimental results

(Screenshot: program output showing the predicted class.)

Result visualization:
(Figure: scatter plot of the four training movies and the test sample.)

As the figure shows, the test sample lies closer to the action movies, so it is classified as an action movie, and the classification result is correct.

4. Handwriting the k-nearest neighbor method to complete the salesman evaluation task

Dataset description

An example of the original dataset is as follows:
(Screenshot: sample rows of the data set.)

Each line of the file is one sample with 4 tab-separated columns: the first 3 columns are attributes (flight mileage, number of customer contacts, number of deals closed) and the 4th column is the label (the company's evaluation of the salesperson). There are 1000 samples in total, 900 for training and 100 for testing.

Dataset visualization:
(Figure: 3D scatter plot of the normalized samples, colored by label.)

Dataset download: see the CSDN resource download (free).

Experimental code

# -*- coding: utf-8 -*-
# @Author : Xenon
# @Date : 2023/2/4 22:35 
# @IDE : PyCharm(2022.3.2) Python3.9.13
import os
import operator
import numpy as np
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D


def classify0(inX, dataSet, labels, k):
    """手写kNN算法的实现
    :param inX: 新样本
    :param dataSet: 已知样本
    :param labels: 已知样本标签
    :param k: 近邻样本数量
    :return: 分类结果sortedClassCount[0][0]
    """
    dataSetSize = dataSet.shape[0]  # number of known samples (size of the first dimension)
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet  # tile the new sample and subtract every known sample
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    # Euclidean distance between the new sample and every known sample
    distances = sqDistances ** 0.5
    sortedDistIndices = distances.argsort()
    classCount = {}
    # tally the classes of the k nearest known samples
    for i in range(k):
        voteIlabel = labels[sortedDistIndices[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # the most frequent class among the k nearest samples becomes the class of the new sample
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def file2matrix(filename):
    """读取数据
    读取指定文件的内容,文件每一行为一个样本,共有4列,以制表符分割,前3列为属性(飞行里程数、联系客户次数、促成交易数),
    第4列为标签(公司对业务员的评价)。将属性与标签分别存储,并将标签转化为整形。
    :param filename: 文件名
    :return: 属性集returnMat, 标签集classLabelVector
    """
    fr = open(filename)
    arrayOLines = fr.readlines()
    numberOfLines = len(arrayOLines)
    returnMat = np.zeros((numberOfLines, 3))
    classLabelVector = []
    index = 0
    for line in arrayOLines:
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index, :] = listFromLine[0:3]
        if listFromLine[-1] == 'didntLike':
            classLabelVector.append(1)
        elif listFromLine[-1] == 'smallDoses':
            classLabelVector.append(2)
        elif listFromLine[-1] == 'largeDoses':
            classLabelVector.append(3)
        index += 1
    return returnMat, classLabelVector


def autoNorm(dataSet):
    """进行Min-Max归一化
    获取所有样本各个属性的最小值minVals,最大值maxVals,利用Min-Max归一化公式进行归一化。
    在样本各属性尺度不同时,kNN法必须进行归一化处理,
    任务1中不是必须进行归一化处理,是因为事先已知样本的各属性是同一尺度的(都是电影中某场景出现次数)
    :param dataSet: 原始属性集
    :return: 归一化数据结果normDataSet(1000个), 数据范围ranges, 最小值minVals
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals  # per-attribute value range
    normDataSet = np.zeros(np.shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - np.tile(minVals, (m, 1))
    normDataSet = normDataSet / np.tile(ranges, (m, 1))
    # print(normDataSet)
    return normDataSet, ranges, minVals


def salemanClassTest():
    """打开salemanTestSet.txt文件,取前10%数据为测试集,后90%数据为训练集,输出测试结果与错误率"""
    filename = "salemanTestSet.txt"
    salemanDataMat, salemanLabels = file2matrix(filename)  # load the samples and their labels
    hoRatio = 0.10  # fraction of the data held out as the test set
    normMat, ranges, minVals = autoNorm(salemanDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    errorCount = 0.0  # error counter

    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],
                                     salemanLabels[numTestVecs:m], 4)
        print("分类结果:%d\t真实类别:%d" % (classifierResult, salemanLabels[i]))
        if classifierResult != salemanLabels[i]:  # 分类错误
            errorCount += 1.0
    print("错误率:%f%%" % (errorCount / float(numTestVecs) * 100))


def visualization():
    """Visualize the normalized data set as a 3D scatter plot, colored by label."""
    # These two rcParams settings let matplotlib render the Chinese labels correctly
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    # plotting setup
    fig = plt.figure()
    # ax = Axes3D(fig)
    ax = fig.add_subplot(111, projection='3d')

    # load the data
    filename = "salemanTestSet.txt"
    salemanDataMat, salemanLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(salemanDataMat)
    # plot each class with its own color (one scatter call per class, so the legend labels match the colors)
    labelArr = np.asarray(salemanLabels)
    for lab, color, name in [(1, 'r', 'didntLike'), (2, 'b', 'smallDoses'), (3, 'g', 'largeDoses')]:
        pts = normMat[labelArr == lab]
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], c=color, marker='o', label=name)
    plt.legend()
    ax.set_xlabel('飞行里程数')   # flight mileage
    ax.set_ylabel('联系客户次数')  # number of customer contacts
    ax.set_zlabel('促成交易数')   # number of deals closed
    plt.title('样本可视化')  # figure title ("sample visualization")

    path = os.getcwd()  # current working directory
    fileName = "002"
    filePath = path + "\\" + fileName + ".png"
    plt.savefig(filePath, dpi=600)  # a higher dpi gives a sharper (and larger) image

    plt.show()


if __name__ == '__main__':
    salemanClassTest()
    # visualization()  # uncomment to visualize the data set

Experimental results

(Screenshot: per-sample predictions and the overall error rate.)

Using the K-nearest neighbor method to classify these 100 test samples gives the partial results shown above; the overall classification error rate is 4%, i.e., only 4 samples were misclassified. Judged by error rate alone, the result is quite good.

5. Implementing the k-nearest neighbor method with sk-learn for handwritten digit recognition

This data set was described in detail in the earlier blog post on feedforward neural networks and support vector machines, so it is not explained again here.

Experimental code

# -*- coding: utf-8 -*-
# @Author : Xenon
# @Date : 2023/2/4 22:41 
# @IDE : PyCharm(2022.3.2) Python3.9.13

from os import listdir
import numpy as np
from sklearn.neighbors import KNeighborsClassifier as kNN


def img2vector(filename):
    """将图片展开
    将32*32的二维数组转化为长度1024的一维数组
    :param filename: 图片文件
    :return: 转换后的图片文件returnVect
    """
    returnVect = np.zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def handwritingClassTest():
    hwLabels = []
    # Read the files under the trainingDigits folder.
    # Each file is a 32*32 matrix representing a handwritten digit: strokes are 1s, blanks are 0s.
    trainingFileList = listdir('trainingDigits')
    m = len(trainingFileList)
    trainingMat = np.zeros((m, 1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        # the first number of the file name (before the underscore) is the digit itself, i.e., the sample label
        classNumber = int(fileNameStr.split('_')[0])
        hwLabels.append(classNumber)
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
    # build a kNN classifier with scikit-learn's KNeighborsClassifier
    neigh = kNN(n_neighbors=3, algorithm='auto')
    # fit the classifier on the training data
    neigh.fit(trainingMat, hwLabels)
    # read the files under the testDigits folder
    testFileList = listdir('testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    # for every test file: convert its pixel matrix, read its label, predict with predict(), and count errors
    for i in range(mTest):
        fileNameStr = testFileList[i]
        classNumber = int(fileNameStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))
        classifierResult = neigh.predict(vectorUnderTest)[0]  # predict() returns an array; take its single element
        print("分类返回结果为%d\t真实结果为%d" % (classifierResult, classNumber))  # predicted vs. true digit
        if classifierResult != classNumber:
            errorCount += 1.0
    print("总共错了%d个数据\n错误率为%f%%" % (errorCount, errorCount / mTest * 100))


if __name__ == '__main__':
    handwritingClassTest()

Experimental results

(Screenshot: per-file predictions and the overall error rate.)
The overall classification error rate is 1.268499%, which is quite good.
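
As a hedged variant (reusing trainingMat and hwLabels from the script above), scikit-learn's classifier can also weight neighbors by inverse distance, echoing the weighting idea discussed in section 1:

```python
# Closer neighbours get a larger say in the vote; farther ones a smaller one.
neigh_w = kNN(n_neighbors=3, weights='distance', algorithm='auto')
neigh_w.fit(trainingMat, hwLabels)
```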

6. Implementing a decision tree with sk-learn for the contact-lens evaluation task

Experimental code

# -*- coding: utf-8 -*-
# @Author : Xenon
# @Date : 2023/2/4 22:46 
# @IDE : PyCharm(2022.3.2) Python3.9.13

import pandas as pd
from sklearn import tree
from sklearn.preprocessing import LabelEncoder


def main():
    # Read lenses.txt; each line of the file is one sample.
    # The first 4 columns are attributes ('age', 'prescript', 'astigmatic', 'tearRate'), the last column is the label.
    with open('lenses.txt', 'r') as fr:
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lenses_target = []
    for each in lenses:
        lenses_target.append(each[-1])
    print(lenses_target)

    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lenses_list = []
    lenses_dict = {}
    for each_label in lensesLabels:
        for each in lenses:
            lenses_list.append(each[lensesLabels.index(each_label)])
        lenses_dict[each_label] = lenses_list
        lenses_list = []

    # store the data in a pandas DataFrame
    lenses_pd = pd.DataFrame(lenses_dict)

    le = LabelEncoder()  # encode the string-valued attributes as integers
    for col in lenses_pd.columns:
        lenses_pd[col] = le.fit_transform(lenses_pd[col])

    # instantiate scikit-learn's DecisionTreeClassifier with max_depth=4
    clf = tree.DecisionTreeClassifier(max_depth=4)
    # train with fit(), passing the attribute matrix and the label list
    clf.fit(lenses_pd.values.tolist(), lenses_target)
    # test with predict(); the test set is the training set itself, to check whether
    # the tree reproduces the true labels
    print(clf.predict(lenses_pd.values.tolist()))


if __name__ == '__main__':
    main()

Experimental results

(Screenshot: the predicted labels for the lens samples.)
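
To inspect what the tree actually learned, a hedged addition (placed at the end of main(), reusing clf and lensesLabels from the script above) could print the tree as text with scikit-learn's export_text:

```python
from sklearn.tree import export_text
# Shows the attribute tested at each node and the class predicted at each leaf.
print(export_text(clf, feature_names=lensesLabels))
```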

7. Writing a decision tree by hand for the bank-customer credit evaluation task

Experimental code

# -*- coding: utf-8 -*-
# @Author : Xenon
# @Date : 2023/2/4 22:51 
# @IDE : PyCharm(2022.3.2) Python3.9.13
"""
提示:本任务可能涉及香农熵的计算、决策树最优属性的选择,其中可能会使用递归方法,如果遇到困难,请回顾一下Python基础知识,
如递归的使用方法,以及python中函数内部修改变量和列表时,其它位置的变量和列表是否同步的规律
"""
import math
import operator


def calcShannonEnt(dataSet):
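    """Compute the Shannon entropy of the class labels (the last column) of dataSet."""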
    numEntires = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntires
        shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt


def createDataSet():
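    """Bank-customer data set.
    Each sample is [age, has a job, owns a house, credit record, label]; the feature
    values are integer-coded and the label 'yes'/'no' means high/low credit.
    """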
    dataSet = [[0, 0, 0, 0, 'no'],
               [0, 0, 0, 1, 'no'],
               [0, 1, 0, 1, 'yes'],
               [0, 1, 1, 0, 'yes'],
               [0, 0, 0, 0, 'no'],
               [1, 0, 0, 0, 'no'],
               [1, 0, 0, 1, 'no'],
               [1, 1, 1, 1, 'yes'],
               [1, 0, 1, 2, 'yes'],
               [1, 0, 1, 2, 'yes'],
               [2, 0, 1, 2, 'yes'],
               [2, 0, 1, 1, 'yes'],
               [2, 1, 0, 1, 'yes'],
               [2, 1, 0, 2, 'yes'],
               [2, 0, 0, 0, 'no']]
    labels = ['年龄', '有工作', '有自己的房子', '信贷情况']  # age, has a job, owns a house, credit record
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
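    """Return the samples whose value at feature index `axis` equals `value`, with that feature column removed."""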
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
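    """Return the index of the feature with the largest information gain (ID3 criterion)."""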
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy

        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
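    """Return the most frequent class in classList (used when no features are left to split on)."""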
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels, featLabels):
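    """Recursively build the decision tree (ID3, information gain) as nested dicts.
    The names of the chosen features are appended to featLabels in splitting order;
    recursion stops when all samples share one class or no features remain.
    """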
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1 or len(labels) == 0:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
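    """Classify testVec by walking the tree; the feature tested at each node is located
    in testVec via featLabels, which was filled in during tree construction."""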
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


if __name__ == '__main__':
    dataSet, labels = createDataSet()
    featLabels = []
    myTree = createTree(dataSet, labels, featLabels)  # build the tree; featLabels records the splitting order
    testVec = [0, 1]  # attribute values of the test sample, in the order recorded in featLabels
    result = classify(myTree, featLabels, testVec)
    if result == 'yes':
        print('测试样本%s为: 高信用' % testVec)  # high credit
    if result == 'no':
        print('测试样本%s为: 低信用' % testVec)  # low credit

Experimental results

(Screenshot: classification result printed for the test sample.)


Original post: blog.csdn.net/yxn4065/article/details/128909790