Crea un árbol de decisiones

Disposición del artículo:
(1) Primero analice cómo dividir los datos desde la perspectiva de la teoría de la información
(2) Aplique fórmulas matemáticas al conjunto de datos real
(3) Dibuje un árbol de decisión
(4) Una rutina práctica
(5) Resumen experimental

1. ¿Qué es un árbol de decisiones?

Árbol de decisión, ver literalmente. Un árbol que puede tomar decisiones.

Como la detección de spam común (Naive Bayes también está bien) (Detecte la dirección de dominio del correo electrónico enviado, identifique la información en el correo electrónico, encuentre las palabras que a menudo aparecen en el correo no deseado como: descuento, gratis, compra) y luego tome una decisión sobre si es correo no deseado o correo electrónico de comunicación normal.

Por qué elegir un árbol de decisión
Ventaja: la complejidad computacional no es alta y el resultado es fácil de entender.
Desventaja: es fácil producir un ajuste excesivo.
Escenarios de uso: datos numéricos y nominales

2. Construcción del árbol de decisiones

Proceso:
(1) Recopilar datos: se puede utilizar cualquier método (el conjunto de datos se proporciona al final del artículo, jaja)
(2) Preparación de datos: el algoritmo de construcción de árboles solo se aplica a datos nominales, por lo que los datos numéricos deben discretizarse (similar al módulo de procesamiento de señales digitales Cuantificación de la amplitud de conversión)
(3) Analizar los datos: se puede usar cualquier método, después de que se construye el árbol, se debe verificar si el gráfico cumple con las expectativas en el tiempo
(4) algoritmo de entrenamiento: la estructura de datos del árbol de construcción
(5) algoritmo de prueba: cálculo usando el árbol de experiencia Tasa de error
(6) Usa algoritmo: escribe un blog que se pueda ejecutar sin errores, jajaja

Usando el algoritmo ID3, cada vez que el conjunto de datos se divide en una característica, ¿qué característica debe seleccionarse como base para nuestra división?

2.1 Ganancia de información

Base de clasificación: los datos similares se dividen en una categoría. Después de cada clasificación, los datos en la misma estructura de rama tienen mayor similitud. Y la eliminación de las características utilizadas como base para la clasificación hará que el conjunto de datosEl grado de desorden disminuye y los datos se vuelven más ordenados.

¿Cuál es la base para medir el orden de los datos?
La introducción de la teoría de la información
El cambio de información antes y después de dividir el conjunto de datos se denomina ganancia de información, y la característica con la mayor ganancia de información es la mejor base para nuestra clasificación.
La medida de la información agregada se llama entropía o entropía de Shannon. Cuanto
mayor es la entropía, mayor es el caos de los datos, más desordenados son los datos y menor la similitud.

La entropía es el valor esperado de la información. Si la transacción a clasificar se puede dividir en múltiples categorías, la información del símbolo x (i) se define como:
Inserte la descripción de la imagen aquí
necesitamos calcular la entropía del conjunto de datos, por lo que solo necesitamos sumar y sumar:
Inserte la descripción de la imagen aquí
hasta ahora, Podemos analizar cómo cambia la entropía del conjunto de datos antes y después de la división.

Calcule la entropía de Shannon del conjunto de datos

from math import log
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {
    
    }
    for featVec in dataSet: #the the number of unique elements and their occurance
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2) #log base 2
    return shannonEnt

Inserte la descripción de la imagen aquí
Cuanto mayor sea la entropía, más datos mezclados. Podemos agregar más categorías al conjunto de datos y observar los cambios en la entropía del conjunto de datos
Inserte la descripción de la imagen aquí

2.2 División del conjunto de datos

Divida el conjunto de datos de acuerdo con las características dadas, el primer parámetro es el conjunto de datos que se dividirá, el segundo parámetro es la característica del conjunto de datos dividido y el tercer parámetro es el valor de característica que debe devolverse

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

Inserte la descripción de la imagen aquí
Ejemplo: Divida el conjunto de datos según la característica 0 y devuelva los registros de datos con la característica 0 como 1

2.3 Elección del mejor método de división de conjuntos de datos

En la sección anterior, aprendimos cómo medir el nivel de confusión en un conjunto de datos. Nuestro objetivo: dividir continuamente el conjunto de datos en pequeñas ramas de acuerdo con las leyes internas de los datos, hasta que se clasifiquen los nodos hoja.

Calcularemos la entropía una vez para el conjunto de datos dividido por cada característica, y luego juzgaremos si la característica es la mejor característica de división del conjunto de datos (es decir, un proceso de decisión)

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):        #iterate over all the features
        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
        uniqueVals = set(featList)       #get a set of unique values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
        if (infoGain > bestInfoGain):       #compare this to the best gain so far
            bestInfoGain = infoGain         #if better than current best, set to best
            bestFeature = i
    return bestFeature 

Aquí, cada característica se utiliza como base para dividir el conjunto de datos, y luego se obtiene la entropía newEntropía del conjunto de datos dividido . Cuando la entropía del nuevo conjunto de datos es la más pequeña, es decir, cuando se obtiene la máxima ganancia de información antes y después de seleccionar la característica para dividir el conjunto de datos, Se reduce la confusión de los datos y esta función es la mejor función de división de conjuntos de datos.
Inserte la descripción de la imagen aquí

2.4 Construir árboles de decisión de forma recursiva

El signo del final de la división del conjunto de datos anterior es: cada registro se asigna a un nodo hoja, cada vez que se selecciona una característica, el conjunto de datos se divide en varios conjuntos de datos más pequeños, solo necesitamos continuar con el conjunto de datos de rama dividido Use la función chooseBestFeatureToSplit para continuar dividiendo el conjunto de datos, y la recursividad simplemente puede implementar este proceso.
En la operación real, también podemos especificar el número de nodos hoja (es decir, el número de categorías, algoritmo C4.5 y CART)

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]#stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {
    
    bestFeatLabel:{
    
    }}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

Aquí, el funcionamiento de la rama del árbol se simula mediante el anidamiento del diccionario.
Inserte la descripción de la imagen aquí

3. Dibuja un árbol de decisiones

Si desea dibujar un árbol atractivo, debe calcular la profundidad y el ancho del árbol. Con las funciones de matplotlib, el proceso es demasiado engorroso. Aquí puede llamar directamente a == treePlotter.createPlot (lensestree) == para dibujar un hermoso árbol de decisiones

'''
Created on Oct 14, 2010

@author: Peter Harrington
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree)[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            plotTree(secondDict[key], cntrPt, str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

#def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()

def retrieveTree(i):
    listOfTrees = [{
    
    'no surfacing': {
    
    0: 'no', 1: {
    
    'flippers': {
    
    0: 'no', 1: 'yes'}}}},
                   {
    
    'no surfacing': {
    
    0: 'no', 1: {
    
    'flippers': {
    
    0: {
    
    'head': {
    
    0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees,[i]

# createPlot(thisTree)

Código de prueba

if __name__ == '__main__':
    mydat,labels=createDataSet()
    thistree = createTree(mydat,labels)
    treePlotter.createPlot(thistree)

Patrón de prueba:

Inserte la descripción de la imagen aquí

4. Una rutina práctica

Utilice el árbol de decisiones para predecir el tipo de lente de contacto. El
código fuente del conjunto de datos está disponible. Si necesita comentar y dejar un mensaje, también puede enviarme un correo electrónico [email protected]

Leyenda de resultados experimentales:

Inserte la descripción de la imagen aquí

5. Resumen del experimento

El núcleo del árbol de decisiones es el proceso de toma de decisiones. Usamos la entropía de Shannon de la teoría de la información para medir el grado de confusión de los datos divididos y luego obtener la base para la toma de decisiones. El algoritmo ID3 usa una sola característica para dividir hasta que se divide en nodos hoja. Los algoritmos C4.5 y CART son actualmente populares y llenarán el vacío en el futuro, jaja.
El problema de coincidencia excesiva del árbol de decisiones se puede realizar cortando el árbol de decisiones para mejorar la capacidad de generalización del algoritmo y llenar el hueco más tarde, jaja.

Supongo que te gusta

Origin blog.csdn.net/ca___0/article/details/109606671
Recomendado
Clasificación