Árbol de decisión - clasificación basada en las características de la fruta

1. Obtenga el conjunto de datos

Entre las frutas, las manzanas y la carambola tienen características externas relativamente distintas. Por ejemplo, en las siguientes dos imágenes de manzanas y carambolas, las manzanas son de color rojo, de forma aproximadamente ovalada, lisas sin esquinas y tienen hojas. La carambola es amarilla, en forma de pentagrama, y ​​tiene esquinas. , Sin hojas.
inserte la descripción de la imagen aquí
Utilice las características anteriores para contar algunos datos de manzana y carambola:

  • Color: 1-rojo 0-amarillo
  • Forma: 1-elipse 0-pentagrama
  • Aristas: 1-con aristas 0-sin aristas
  • Con hojas: 1-con hojas 0-sin hojas

inserte la descripción de la imagen aquí

1. Extraer datos

Use la biblioteca CSV para clasificar las características especificadas, extraiga los datos excepto la primera fila y utilícelos como el conjunto de datos para este experimento. La
inserte la descripción de la imagen aquí
primera fila es cada nodo del árbol de decisión, que se almacena en etiquetas; y luego las características corresponde a cada situación Almacenado en etiquetas.
inserte la descripción de la imagen aquí

# 获取数据集
def createDataSet(filename):
    # 读取文件
    data = open(filename, 'rt', encoding='gbk')
    reader = csv.reader(data)
    # 获取标签列
    handlers = next(reader)
    lables = handlers[:-1]
    # 数据列表
    dataSet = []

    for row in reader:
        # 读取除第一行的数据
        dataSet.append(row[:])
        
    # 特征对应的所有可能的情况
    labels_full = {
    
    }
    for i in range(len(lables)):
        labelList = [example[i] for example in dataSet]
        uniqueLabel = set(labelList)
        labels_full[lables[i]] = uniqueLabel
    return dataSet, lables, labels_full

2. Divide los datos

Para la entrada de datos por dataSet, el eje es la coordenada correspondiente en las etiquetas y el valor es el valor del atributo debajo del atributo correspondiente.

# 划分数据集  
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        # 给定特征值等于想要的特征值
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            # 将该特征值后面的内容保存起来
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)

    return retDataSet
print(splitDataSet(dataSet, 1, '0'))

Usa el método majorCnt para obtener la etiqueta con la mayor cantidad de ocurrencias en una colección

# 获取出现次数最多的类别
def majorityCnt(classList):
    classCount = collections.defaultdict(int)
    # 遍历所有的类别
    for vote in classList:
        classCount[vote] += 1
    # 降序排序,第一行第一列就是最多的
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

En segundo lugar, calcule la ganancia de información.

1. Entropía de la información

Primero obtenga toda la longitud de los datos y luego cree un diccionario, el valor clave es el valor de la última columna. Cada valor clave registra el número de ocurrencias de la categoría actual y finalmente calcula la tasa de ocurrencia de todas las etiquetas de clase para calcular la tasa de ocurrencia de la categoría y finalmente calcula el valor de entropía.

# 获取水果信息熵
def calcShannonEnt(dataSet):
    # 总数
    numEntries = len(dataSet)
    # 用来统计标签
    labelCounts = collections.defaultdict(int)
    # 循环整个数据集,得到数据的分类标签
    for featVec in dataSet:
        # 得到当前的标签
        currentLabel = featVec[-1]
        labelCounts[currentLabel] += 1
    # 计算信息熵
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

2. Calcular la ganancia de información

Para calcular la ganancia de información, primero obtenga el número de todas las características, excluyendo la clasificación final de la fruta; luego calcule la entropía de información correspondiente a cada característica; finalmente reste la entropía de información de la clasificación de la entropía de información de la característica, que es la ganancia de información de la función correspondiente. Después de obtener la ganancia de información de cada característica, devuelva el subíndice de la etiqueta correspondiente al valor máximo y utilícelo como el nodo raíz del número al construir el árbol de decisión.
La ganancia de información correspondiente a cada característica, y finalmente devuelve el subíndice correspondiente a la etiqueta más grande:
inserte la descripción de la imagen aquí

# 计算每个特征信息增益
def chooseBestFeatureToSplit(dataSet, labels):
    # 特征数 总的列数减去最后的一列
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    # 对每个特征值进行求信息熵
    for i in range(numFeatures):
        # 得到数据集中所有的当前特征值列表
        featList = [example[i] for example in dataSet]
        # 当前特征值中共有多少种
        uniqueVals = set(featList)
        newEntropy = 0.0

        # 遍历现在有的特征的可能性
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet=dataSet, axis=i, value=value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy

        print( labels[i] + '信息增益值为:' + str(infoGain))
        # 找出最大的值
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
print(chooseBestFeatureToSplit(dataSet, lables))

3. Dibujar un árbol de decisión

Ingrese un conjunto de datos y una matriz de etiquetas para obtener un árbol de decisiones similar a un diccionario.
Primero obtenga las etiquetas de clasificación de todos los conjuntos de datos y luego cuente la cantidad de ocurrencias de la primera etiqueta y compárela con la cantidad total de etiquetas. Calcule cuántos datos hay en la primera línea. Si solo hay uno, significa que se han recorrido todos los atributos de la entidad, y el restante es la etiqueta de categoría, o todas las muestras son consistentes en todos los atributos, y luego devuelve el número de ocurrencias en las etiquetas restantes usando majorityCntEl que tiene más. Después de chooseBestFeatureToSplitseleccionar la mejor función de división, obtenga el subíndice de la función como nodo raíz. Finalmente, se llama de forma recursiva para dividir todos los datos en el conjunto de datos cuya función es igual al valor de la función actual en el nodo actual. Cuando se llama de forma recursiva, la función actual debe eliminarse primero.

# 绘制决策树
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    print(classList)
    # 统计第一个标签出现的次数,与总标签个数比较
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    if len(dataSet[0]) == 1 :
        # 返回剩下标签中出现次数较多的那个
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet=dataSet, labels=labels)
    bestFeatLabel = labels[bestFeat]

    myTree = {
    
    bestFeatLabel: {
    
    }}

    # 将本次划分的特征值从列表中删除掉
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)

    # 遍历所有的特征值
    for value in uniqueVals:
        subLabels = labels[:]
        subTree = createTree(splitDataSet(dataSet=dataSet, axis=bestFeat, value=value), subLabels)
        # 递归调用
        myTree[bestFeatLabel][value] = subTree
    return myTree
print(createTree(dataSet, lables))

Obtenga un árbol de decisiones en forma de diccionario:

{'带叶': 
    {'1': {'形状': 
            {'1': '苹果', 
            '0': {'棱角': 
                {'1': '杨桃', 
                 '0': '苹果'}}}}, 
     '0': {'棱角': 
            {'1': '杨桃', 
             '0': {'颜色': 
                {'1': {'形状': {'杨桃': '杨桃', '苹果': '苹果'}}, 
                 '0': {'形状': {'杨桃': '杨桃', '苹果': '苹果'}}}}}}}}

4. Predicción de clasificación

La predicción de clases también es una función recursiva que usa el método index para encontrar el primer elemento en la lista actual que coincide con la variable firstStr. A continuación, recorra de forma recursiva todo el árbol, compare el valor de la variable testVec con el valor del nodo del árbol y devuelva la etiqueta de clasificación si llega al nodo hoja.

# 预测
def classify(inTree, featLabel, testVec):
    # 获取第一个节点
    firstStr = list(inTree.keys())[0]
    secondDict = inTree[firstStr]
    # 节点对应下标
    featIndex = featLabel.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            # 递归判断
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabel, testVec)
            else: classLabel = secondDict[key]
    # 返回预测
    return classLabel

Resultados de la prueba:
inserte la descripción de la imagen aquí

Código:
Enlace: https://pan.baidu.com/s/1gjbXKDworG7ejzS6cCTvgQ?pwd=kupj
Código de extracción: kupj

Supongo que te gusta

Origin blog.csdn.net/chenxingxingxing/article/details/127837664
Recomendado
Clasificación