Decision tree - classification based on fruit characteristics

First, get the dataset

Among fruits, apples and carambolas (star fruit) have fairly distinct external features. An apple is red, roughly oval, smooth with no edges, and has leaves; a carambola is yellow, star-shaped (a pentagram in cross-section), has edges, and has no leaves.
Using these characteristics, we recorded data on a number of apples and carambolas, encoding each feature as 0 or 1:

  • Color (颜色): 1 = red, 0 = yellow
  • Shape (形状): 1 = ellipse, 0 = pentagram
  • Edges (棱角): 1 = has edges, 0 = no edges
  • Leaves (带叶): 1 = has leaves, 0 = no leaves
  • Class (last column): 苹果 = apple, 杨桃 = carambola
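As a small illustration of this encoding, the sketch below turns observed traits into a feature vector. encode_fruit is a hypothetical helper, not part of the original code, and the column order (color, shape, edges, leaves) is assumed to follow the list above. It produces the strings '1'/'0' rather than integers because csv.reader returns every field as a string, so the branch keys in the tree are strings as well.

```python
def encode_fruit(is_red, is_ellipse, has_edges, has_leaves):
    """Hypothetical helper: map observed traits to the 0/1 strings used in the CSV.

    Column order (color, shape, edges, leaves) is an assumption.
    """
    return ['1' if flag else '0' for flag in (is_red, is_ellipse, has_edges, has_leaves)]


apple = encode_fruit(is_red=True, is_ellipse=True, has_edges=False, has_leaves=True)
carambola = encode_fruit(is_red=False, is_ellipse=False, has_edges=True, has_leaves=False)
print(apple)      # ['1', '1', '0', '1']
print(carambola)  # ['0', '0', '1', '0']
```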


1. Extract data

Read the CSV file with Python's csv module. The first row is the header: its entries are the feature names, which serve as the nodes of the decision tree, and they are stored in labels (the last column, the fruit class, is excluded). Every other row is one sample and goes into the data set for this experiment. Finally, the set of values that each feature takes is stored in the dictionary labels_full.

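import csv

# Load the data set from a CSV file
def createDataSet(filename):
    # The file is GBK-encoded; newline='' is what the csv module expects
    with open(filename, 'rt', encoding='gbk', newline='') as data:
        reader = csv.reader(data)
        # Header row: feature names, followed by the class column
        headers = next(reader)
        labels = headers[:-1]
        # Every remaining row is one sample (a list of strings)
        dataSet = [row for row in reader]

    # All possible values of each feature
    labels_full = {}
    for i in range(len(labels)):
        labels_full[labels[i]] = {example[i] for example in dataSet}
    return dataSet, labels, labels_full

# The file name is an assumption; the post does not state it
dataSet, lables, labels_full = createDataSet('fruit.csv')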
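To see what createDataSet returns without the original file, the sketch below applies the same steps to a hypothetical in-memory CSV wrapped in io.StringIO. The rows, the column order and the class-column name 水果 are assumptions for illustration, not the author's data; create_data_set_from is a hypothetical variant that accepts a file object instead of a file name.

```python
import csv
import io

# Hypothetical CSV in the assumed layout: four feature columns, then the class column
SAMPLE_CSV = """颜色,形状,棱角,带叶,水果
1,1,0,1,苹果
0,0,1,0,杨桃
1,1,0,0,苹果
0,0,1,1,杨桃
"""


def create_data_set_from(f):
    """Same steps as createDataSet, but reads from an open file-like object."""
    reader = csv.reader(f)
    labels = next(reader)[:-1]            # feature names; the class column is dropped
    data_set = [row for row in reader]    # one list of strings per sample
    labels_full = {name: {row[i] for row in data_set} for i, name in enumerate(labels)}
    return data_set, labels, labels_full


data_set, labels, labels_full = create_data_set_from(io.StringIO(SAMPLE_CSV))
print(labels)                        # ['颜色', '形状', '棱角', '带叶']
print(data_set[1])                   # ['0', '0', '1', '0', '杨桃']
print(sorted(labels_full['棱角']))   # ['0', '1']
```

Note that every value, including the 0/1 features, stays a string; the later functions compare against '0' and '1' accordingly.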

2. Divide the data

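splitDataSet takes the data set dataSet, the position axis of a feature in labels, and a value of that feature. It returns the samples whose feature axis equals value, with that feature's column removed, so the subsets can be passed to the next level of the tree.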

# Split the data set
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        # Keep only samples whose feature `axis` has the requested value
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            # Append the columns after `axis`, dropping the feature itself
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

print(splitDataSet(dataSet, 1, '0'))
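A more compact way to express splitDataSet, assuming the same list-of-strings rows, is a single list comprehension that filters and slices. The name split_data_set and the three rows are illustrative only.

```python
def split_data_set(data_set, axis, value):
    """Keep rows whose column `axis` equals `value`, dropping that column."""
    return [row[:axis] + row[axis + 1:] for row in data_set if row[axis] == value]


# Hypothetical rows: [颜色, 形状, 棱角, 带叶, class]
rows = [
    ['1', '1', '0', '1', '苹果'],
    ['0', '0', '1', '0', '杨桃'],
    ['1', '0', '0', '1', '苹果'],
]
print(split_data_set(rows, 1, '0'))
# [['0', '1', '0', '杨桃'], ['1', '0', '1', '苹果']]
```

Slicing builds new lists, so the input rows are not modified. Like the original, it assumes axis is a non-negative index: with axis = -1 the expression becomes row[:-1] + row, not the row without its last column.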

The majorityCnt function returns the label that occurs most often in a list of class labels:

import collections
import operator

# Return the class that occurs most often
def majorityCnt(classList):
    classCount = collections.defaultdict(int)
    # Count every class label
    for vote in classList:
        classCount[vote] += 1
    # Sort by count, descending; the first entry is the most frequent class
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
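collections.Counter can do the same counting and sorting in one call. This is an alternative sketch, not the author's code, and majority_cnt is a hypothetical name. Counter.most_common orders equal counts by first occurrence, which matches the stable sort used in majorityCnt.

```python
from collections import Counter


def majority_cnt(class_list):
    """Return the most frequent label; ties go to the label seen first."""
    return Counter(class_list).most_common(1)[0][0]


print(majority_cnt(['苹果', '杨桃', '苹果']))  # 苹果
print(majority_cnt(['杨桃', '苹果']))          # 杨桃 (tie, first seen wins)
```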

Second, calculate the information gain

1. Information entropy

calcShannonEnt first counts the samples, then builds a dictionary keyed by class label (the last column) that records how many times each class occurs. From these counts it computes the proportion of each class and, from the proportions, the entropy of the data set.
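In formula form: for a data set D whose samples fall into K classes (here K = 2, apple and carambola), with D_k the samples of class k, the entropy in bits is given below. The worked line uses a hypothetical set of 6 apples and 4 carambolas, not the author's data.

```latex
H(D) = -\sum_{k=1}^{K} p_k \log_2 p_k, \qquad p_k = \frac{|D_k|}{|D|}

% hypothetical example: 6 apples, 4 carambolas
H(D) = -\tfrac{6}{10}\log_2\tfrac{6}{10} - \tfrac{4}{10}\log_2\tfrac{4}{10} \approx 0.442 + 0.529 \approx 0.971
```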

import collections
from math import log

# Entropy of the fruit classes in a data set
def calcShannonEnt(dataSet):
    # Total number of samples
    numEntries = len(dataSet)
    # Count how often each class label occurs
    labelCounts = collections.defaultdict(int)
    for featVec in dataSet:
        # The class label is the last column
        currentLabel = featVec[-1]
        labelCounts[currentLabel] += 1
    # Entropy: -sum(p * log2(p)) over all classes
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
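The same computation in a compact form is handy for checking numbers by hand. shannon_entropy is a hypothetical name; it uses Counter and math.log2 instead of a defaultdict and log(prob, 2), and the data below are hypothetical (6 apples, 4 carambolas).

```python
from collections import Counter
from math import log2


def shannon_entropy(data_set):
    """Entropy (in bits) of the class labels in the last column of each row."""
    n = len(data_set)
    counts = Counter(row[-1] for row in data_set)
    return -sum(c / n * log2(c / n) for c in counts.values())


# Hypothetical data: 6 apples and 4 carambolas (only the class column matters here)
rows = [['苹果']] * 6 + [['杨桃']] * 4
print(round(shannon_entropy(rows), 3))             # 0.971
print(shannon_entropy([['苹果'], ['杨桃']]))        # 1.0 (an even split is maximally uncertain)
print(shannon_entropy([['苹果'], ['苹果']]) == 0)   # True (a pure set has zero entropy)
```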

2. Calculate information gain

To compute information gain, chooseBestFeatureToSplit first counts the features (every column except the last, which is the fruit class) and computes the entropy of the whole data set. For each feature it splits the data by each value of that feature and adds up the entropies of the subsets, each weighted by the subset's share of the samples. Subtracting this weighted entropy from the entropy of the whole data set gives the information gain of the feature. The function prints the gain of every feature and returns the index of the feature with the largest gain, which becomes the root node when the decision tree is built.
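Written as a formula, with V(A) the set of values of feature A and D^v the samples whose feature A equals v, the gain computed by chooseBestFeatureToSplit is shown below; the sum is the weighted entropy that the code accumulates in newEntropy. The two example lines use a hypothetical set of 2 apples and 2 carambolas in which color separates the classes and leaves do not.

```latex
g(D, A) = H(D) - \sum_{v \in V(A)} \frac{|D^{v}|}{|D|} \, H(D^{v})

% hypothetical D: 2 apples, 2 carambolas, so H(D) = 1
g(D, \text{color})  = 1 - \left(\tfrac{2}{4}\cdot 0 + \tfrac{2}{4}\cdot 0\right) = 1
g(D, \text{leaves}) = 1 - \left(\tfrac{2}{4}\cdot 1 + \tfrac{2}{4}\cdot 1\right) = 0
```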

# Compute the information gain of every feature and return the best one
def chooseBestFeatureToSplit(dataSet, labels):
    # Number of features: all columns except the last (class) column
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    # Stays -1 if no feature has a positive information gain
    bestFeature = -1
    for i in range(numFeatures):
        # All values of feature i in the data set
        featList = [example[i] for example in dataSet]
        # Distinct values of feature i
        uniqueVals = set(featList)
        newEntropy = 0.0

        # Weighted entropy of the subsets produced by each value
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet=dataSet, axis=i, value=value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy

        print(labels[i] + ' information gain: ' + str(infoGain))
        # Keep the feature with the largest gain
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
print(chooseBestFeatureToSplit(dataSet, lables))
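A minimal, self-contained check of the information-gain step, assuming a hypothetical four-sample set in which color separates the classes perfectly and leaves carry no information. entropy and info_gain are illustrative names, not the author's functions.

```python
from collections import Counter
from math import log2


def entropy(rows):
    """Entropy (bits) of the class labels in the last column."""
    n = len(rows)
    return -sum(c / n * log2(c / n) for c in Counter(r[-1] for r in rows).values())


def info_gain(rows, axis):
    """H(D) minus the weighted entropy of the subsets produced by feature `axis`."""
    n = len(rows)
    conditional = 0.0
    for value in {r[axis] for r in rows}:
        subset = [r for r in rows if r[axis] == value]
        conditional += len(subset) / n * entropy(subset)
    return entropy(rows) - conditional


# Hypothetical rows: [颜色, 带叶, class]
rows = [
    ['1', '1', '苹果'],
    ['1', '0', '苹果'],
    ['0', '1', '杨桃'],
    ['0', '0', '杨桃'],
]
gains = [info_gain(rows, i) for i in range(len(rows[0]) - 1)]
print(gains)                    # [1.0, 0.0]
print(gains.index(max(gains)))  # 0, so the tree would split on 颜色 first
```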

Third, build the decision tree

createTree takes a data set and a list of feature labels and returns the decision tree as a nested dictionary.
It first collects the class labels of all samples and checks two stopping conditions. If the count of the first label equals the total number of labels, every sample has the same class, and that class is returned as a leaf. If each row has only one column left, all features have been used up and only the class label remains, so majorityCnt returns the most frequent class. Otherwise chooseBestFeatureToSplit picks the best feature, whose label becomes the current node. The function then recurses: for each value of that feature, the samples with that value (with the feature's column removed) are passed to createTree together with the label list minus the current feature, and the returned subtree is stored under that value. Samples that agree on every remaining feature but differ in class are not covered by either stopping condition; for them no feature has positive information gain, so chooseBestFeatureToSplit returns -1.

# Build the decision tree recursively
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # All samples share one class: return it as a leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # Only the class column is left: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet=dataSet, labels=labels)
    # No feature has positive gain (e.g. identical features, different classes);
    # without this check, index -1 would select the class column itself
    if bestFeat == -1:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}

    # Labels for the subtrees: a new list without the feature used here,
    # so the caller's list is not modified
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featValues = [example[bestFeat] for example in dataSet]

    # One branch per value of the chosen feature
    for value in set(featValues):
        subData = splitDataSet(dataSet=dataSet, axis=bestFeat, value=value)
        myTree[bestFeatLabel][value] = createTree(subData, subLabels)
    return myTree
print(createTree(dataSet, lables))
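The sketch below condenses the whole procedure into one self-contained script so the stopping rules can be tried on a toy set. It is an illustrative re-implementation, not the author's code: build_tree, best_feature and the two-feature data are hypothetical. It ends with a majority leaf when no feature has positive information gain, and it builds a new label list instead of deleting from the caller's.

```python
from collections import Counter
from math import log2


def entropy(rows):
    """Entropy (bits) of the class labels in the last column."""
    n = len(rows)
    return -sum(c / n * log2(c / n) for c in Counter(r[-1] for r in rows).values())


def split(rows, axis, value):
    """Rows whose column `axis` equals `value`, with that column removed."""
    return [r[:axis] + r[axis + 1:] for r in rows if r[axis] == value]


def best_feature(rows):
    """Index of the feature with the largest positive gain, or -1 if none has any."""
    base = entropy(rows)
    best, best_gain = -1, 0.0
    for i in range(len(rows[0]) - 1):
        cond = 0.0
        for v in {r[i] for r in rows}:
            subset = split(rows, i, v)
            cond += len(subset) / len(rows) * entropy(subset)
        if base - cond > best_gain:
            best, best_gain = i, base - cond
    return best


def build_tree(rows, labels):
    classes = [r[-1] for r in rows]
    majority = Counter(classes).most_common(1)[0][0]
    if len(set(classes)) == 1:      # pure node: leaf with that class
        return classes[0]
    if len(rows[0]) == 1:           # no features left: majority leaf
        return majority
    axis = best_feature(rows)
    if axis == -1:                  # no feature separates the rows: majority leaf
        return majority
    sub_labels = labels[:axis] + labels[axis + 1:]  # new list; caller's labels untouched
    values = sorted({r[axis] for r in rows})
    return {labels[axis]: {v: build_tree(split(rows, axis, v), sub_labels) for v in values}}


# Hypothetical toy data: [颜色, 带叶, class]
labels = ['颜色', '带叶']
rows = [
    ['1', '1', '苹果'],
    ['1', '0', '苹果'],
    ['0', '0', '杨桃'],
    ['0', '0', '杨桃'],
    ['0', '0', '苹果'],
]
print(build_tree(rows, labels))  # {'颜色': {'0': '杨桃', '1': '苹果'}}
```

In the toy set the last three rows agree on 带叶 but not on the class, so the 颜色 = '0' branch ends in a majority leaf (杨桃) instead of a split on the class column.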

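Running it produced the following dictionary-shaped decision tree. The node keys are the Chinese column names from the CSV (带叶 = leaves, 形状 = shape, 棱角 = edges, 颜色 = color), the branch keys are feature values, and the leaves are classes (苹果 = apple, 杨桃 = carambola). In the innermost 形状 nodes the branch keys are class names rather than 0/1: at those nodes no remaining feature had positive information gain, so chooseBestFeatureToSplit returned -1, and index -1 picked out the class column.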

{'带叶': 
    {'1': {'形状': 
            {'1': '苹果', 
            '0': {'棱角': 
                {'1': '杨桃', 
                 '0': '苹果'}}}}, 
     '0': {'棱角': 
            {'1': '杨桃', 
             '0': {'颜色': 
                {'1': {'形状': {'杨桃': '杨桃', '苹果': '苹果'}}, 
                 '0': {'形状': {'杨桃': '杨桃', '苹果': '苹果'}}}}}}}}

Fourth, classification prediction

Prediction is also recursive. classify reads the feature name stored at the current node (firstStr) and uses list.index to find its position in featLabel, i.e. which element of the test vector testVec to look at. It then follows the branch whose key equals that element, recursing while the branch is a subtree and returning the class label once it reaches a leaf. featLabel must therefore be the full list of feature names in CSV column order.

# Predict the class of one sample
def classify(inTree, featLabel, testVec):
    # Feature name stored at the current node
    firstStr = list(inTree.keys())[0]
    secondDict = inTree[firstStr]
    # Position of that feature in the test vector
    featIndex = featLabel.index(firstStr)
    # Branch matching the sample's value (None if the tree never saw this value)
    subTree = secondDict.get(testVec[featIndex])
    if isinstance(subTree, dict):
        # Internal node: keep descending
        return classify(subTree, featLabel, testVec)
    # Leaf: the predicted class label (or None for an unseen value)
    return subTree
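The same walk can also be written as a loop instead of recursion, which avoids one function call per tree level. This is a sketch under assumptions: classify_iter is a hypothetical name, the column order in feat_labels is assumed, and the tree is hand-written for illustration rather than the one learned above.

```python
def classify_iter(tree, feat_labels, test_vec):
    """Follow the branches matching test_vec down to a leaf.

    Returns the class label at the leaf, or None if the sample has a value
    for which the tree has no branch.
    """
    node = tree
    while isinstance(node, dict):
        feature = next(iter(node))                    # each inner node has one key: the feature name
        value = test_vec[feat_labels.index(feature)]  # the sample's value for that feature
        node = node[feature].get(value)               # None if the value was never seen
    return node


# Hypothetical column order and a hand-written tree (not the tree learned above)
feat_labels = ['颜色', '形状', '棱角', '带叶']
tree = {'带叶': {'1': {'形状': {'1': '苹果',
                                '0': {'棱角': {'1': '杨桃', '0': '苹果'}}}},
                 '0': {'棱角': {'1': '杨桃', '0': '苹果'}}}}

print(classify_iter(tree, feat_labels, ['1', '1', '0', '1']))  # 苹果
print(classify_iter(tree, feat_labels, ['0', '0', '1', '0']))  # 杨桃
print(classify_iter(tree, feat_labels, ['1', '1', '0', '2']))  # None
```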

Test Results:

Code:
Link: https://pan.baidu.com/s/1gjbXKDworG7ejzS6cCTvgQ?pwd=kupj
Extraction code: kupj


Origin blog.csdn.net/chenxingxingxing/article/details/127837664