Decision tree visualization, plus xgb and lgb tree visualization

Table of contents

Import packages

Decision tree visualization methods

 xgb and lgb modeling

 xgb visualization

 lgb visualization


Import packages

from sklearn.datasets import load_iris
from sklearn.tree import plot_tree
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

Decision tree visualization methods

1. Method 1: no need to install the graphviz package

        Since real-world modeling mostly relies on xgb, lgb, and similar boosters, a single decision tree is used more for exploratory analysis, so decision tree modeling and visualization are wrapped into one function here for convenience and practicality.

def DecisionTree_plot(x, y, feature_names=None, target_names=None, max_depth=3, min_samples_leaf=10):
    # fit a shallow decision tree and draw it with sklearn's native plot_tree
    clf = DecisionTreeClassifier(max_depth=max_depth, min_samples_leaf=min_samples_leaf).fit(x, y)
    plt.figure(dpi=100, figsize=(8, 8))
    plot_tree(clf, filled=True, rounded=True,
              feature_names=feature_names,
              class_names=target_names)
    plt.show()

iris = load_iris()
DecisionTree_plot(iris.data, iris.target, iris.feature_names, iris.target_names)

         As you can see, the figure produced by sklearn's native plot_tree is still fairly rough.
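If you want to keep the figure as an image file rather than only display it, a small variation of the same wrapper can be used; this is just a sketch, and the output name decision_tree.png is an example:

def DecisionTree_plot_to_file(x, y, feature_names=None, target_names=None,
                              max_depth=3, min_samples_leaf=10, out_file='decision_tree.png'):
    # same as DecisionTree_plot, but write the figure to disk instead of showing it
    clf = DecisionTreeClassifier(max_depth=max_depth, min_samples_leaf=min_samples_leaf).fit(x, y)
    plt.figure(dpi=100, figsize=(8, 8))
    plot_tree(clf, filled=True, rounded=True,
              feature_names=feature_names, class_names=target_names)
    plt.savefig(out_file, bbox_inches='tight')
    plt.close()

DecisionTree_plot_to_file(iris.data, iris.target, iris.feature_names, iris.target_names)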

2. Method 2: requires the graphviz package to be installed

import graphviz

def DecisionTree_plot2(x, y, feature_names=None, target_names=None, max_depth=3, min_samples_leaf=10):
    # fit the tree, export it as DOT source, and render it with graphviz
    clf = DecisionTreeClassifier(max_depth=max_depth, min_samples_leaf=min_samples_leaf).fit(x, y)
    dot_data = tree.export_graphviz(clf,
                                    feature_names=feature_names,
                                    class_names=target_names,
                                    filled=True, rounded=True,
                                    special_characters=True)
    graph = graphviz.Source(dot_data)
    return graph

DecisionTree_plot2(iris.data, iris.target, iris.feature_names, iris.target_names)
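In a notebook the returned graphviz.Source object renders inline; in a plain script you can write it to a file instead. A minimal sketch, where the output name iris_tree is just an example:

graph = DecisionTree_plot2(iris.data, iris.target, iris.feature_names, iris.target_names)
graph.render('iris_tree', format='png', cleanup=True)  # writes iris_tree.png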

 

 xgb and lgb modeling

import xgboost as xgb
import lightgbm as lgb
import pandas as pd
import numpy as np

df = pd.read_excel('data.xlsx')
xgb_df = xgb.DMatrix(df.drop('y', axis=1), label=df.y)
lgb_df = lgb.Dataset(df.drop('y', axis=1), label=df.y)

# eta and min_child_weight are also accepted by LightGBM as aliases
# for learning_rate and min_sum_hessian_in_leaf
param = {'max_depth': 3, 'eta': 0.2, 'min_child_weight': 50}
xgb_model = xgb.train(param, xgb_df)
lgb_model = lgb.train(param, lgb_df)
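data.xlsx is not provided with the post, so if you only want to reproduce the visualizations below, here is a hedged sketch that trains the same two boosters on the iris data loaded earlier; the column renaming, the binarized label, and the dropped min_child_weight are adjustments for this tiny 150-row dataset:

# assumption: no data.xlsx at hand, so reuse the iris data from above
X = pd.DataFrame(iris.data,
                 columns=[c.replace(' (cm)', '').replace(' ', '_') for c in iris.feature_names])
y = (iris.target == 0).astype(int)   # binarize: setosa vs. the rest

xgb_df = xgb.DMatrix(X, label=y)
lgb_df = lgb.Dataset(X, label=y)

param = {'max_depth': 3, 'eta': 0.2}  # min_child_weight=50 would be too strict for 150 rows
xgb_model = xgb.train(param, xgb_df, num_boost_round=10)
lgb_model = lgb.train(param, lgb_df, num_boost_round=10)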

 xgb visualization

        Here num_trees is the index of the tree (subtree) to draw.

xgb.to_graphviz(xgb_model, num_trees=0, rankdir='UT')
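to_graphviz returns a graphviz object that displays inline in a notebook; to save it to a file, or to draw the tree on a matplotlib axis instead, a short sketch (the output name xgb_tree_0 is just an example):

# save the graphviz rendering to disk
graph = xgb.to_graphviz(xgb_model, num_trees=0)
graph.render('xgb_tree_0', format='png', cleanup=True)  # writes xgb_tree_0.png

# matplotlib alternative
fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
xgb.plot_tree(xgb_model, num_trees=0, ax=ax)
plt.show()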

 lgb visualization

lgb.create_tree_digraph(lgb_model, tree_index=0, encoding='UTF-8')
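create_tree_digraph returns a graphviz.Digraph, so it can be saved in the same way, and lgb.plot_tree is the matplotlib counterpart. A short sketch (the output name lgb_tree_0 is just an example):

# save the digraph to disk
graph = lgb.create_tree_digraph(lgb_model, tree_index=0)
graph.render('lgb_tree_0', format='png', cleanup=True)  # writes lgb_tree_0.png

# matplotlib alternative
fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
lgb.plot_tree(lgb_model, tree_index=0, ax=ax)
plt.show()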

 

 

Follow the public account "Python risk control model and data analysis" for more knowledge and code sharing.

Excellent articles from past issues

Detailed explanation of CatBoost principles

Detailed explanation of CatBoost parameters, with hands-on practice

Model interpretability: SHAP value principles and practice

pyecharts dynamic interactive charts: large-screen visualization

pyecharts dynamic charts embedded in PPT

xgboost principles (easy to understand, no derivations)

lgb in practice: a risk control algorithm competition
