Decision tree visualization method
Import packages
```python
from sklearn.datasets import load_iris
from sklearn.tree import plot_tree
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
```
Decision tree visualization methods
1. Method 1: plot_tree (no need to install the Graphviz software package)
In practice, production models are usually built with XGBoost, LightGBM and similar libraries, so a single decision tree is used mostly for analysis. For convenience, the tree fitting and plotting steps are wrapped into one function here.
```python
def DecisionTree_plot(x, y, feature_names=None, target_names=None, max_depth=3, min_samples_leaf=10):
    # Fit a shallow tree; max_depth and min_samples_leaf keep the plot readable
    clf = DecisionTreeClassifier(max_depth=max_depth, min_samples_leaf=min_samples_leaf).fit(x, y)
    plt.figure(dpi=100, figsize=(8, 8))
    # plot_tree draws the tree with matplotlib only, so Graphviz is not required
    plot_tree(clf, filled=True, rounded=True, feature_names=feature_names,
              class_names=target_names)
    plt.show()

iris = load_iris()
DecisionTree_plot(iris.data, iris.target, iris.feature_names, iris.target_names)
```
As you can see, the plot produced by scikit-learn's built-in method is still fairly crude.
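Some of that crudeness comes from plot_tree's default settings. Below is a minimal sketch that uses a few display options plot_tree does accept (impurity, proportion, precision, fontsize) together with a wider figure. The specific values, the headless Agg backend and the output file name `tree_plot.png` are assumptions for illustration, not part of the original method.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, min_samples_leaf=10, random_state=0)
clf.fit(iris.data, iris.target)

# A wider canvas gives each node box more room than the square 8x8 figure
fig, ax = plt.subplots(figsize=(14, 8), dpi=100)
annotations = plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    rounded=True,
    impurity=False,   # hide the gini line to declutter each box
    proportion=True,  # show class shares instead of raw sample counts
    precision=2,      # round thresholds and values to 2 decimals
    fontsize=10,      # fixed font size instead of automatic scaling
    ax=ax,
)
fig.savefig("tree_plot.png", bbox_inches="tight")  # hypothetical output path
plt.close(fig)
```

Which options help most depends on the tree; hiding impurity and switching to proportions usually shortens each box the most.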
2. Method 2: export_graphviz (requires the Graphviz software package as well as the graphviz Python package)
```python
import graphviz

def DecisionTree_plot2(x, y, feature_names=None, target_names=None, max_depth=3, min_samples_leaf=10):
    clf = DecisionTreeClassifier(max_depth=max_depth, min_samples_leaf=min_samples_leaf).fit(x, y)
    # Export the fitted tree as DOT source text (out_file=None returns a string)
    dot_data = tree.export_graphviz(clf,
                                    out_file=None,
                                    feature_names=feature_names,  # use the arguments, not the global iris object
                                    class_names=target_names,
                                    filled=True, rounded=True,
                                    special_characters=True)
    # In Jupyter, a returned graphviz.Source object is rendered inline
    graph = graphviz.Source(dot_data)
    return graph

DecisionTree_plot2(iris.data, iris.target, iris.feature_names, iris.target_names)
```
XGBoost and LightGBM modeling
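```python
import lightgbm as lgb
import xgboost as xgb
import pandas as pd
import numpy as np

# data.xlsx is expected to hold the features plus a target column named 'y'
df = pd.read_excel('data.xlsx')
xgb_df = xgb.DMatrix(df.drop('y', axis=1), label=df.y)
lgb_df = lgb.Dataset(df.drop('y', axis=1), label=df.y)
# LightGBM accepts 'eta' and 'min_child_weight' as parameter aliases, so one dict works for both
param = {'max_depth': 3, 'eta': 0.2, 'min_child_weight': 50}
xgb_model = xgb.train(param, xgb_df)
lgb_model = lgb.train(param, lgb_df)
```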
XGBoost visualization
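Here num_trees is the index of the tree (boosting round) to draw, and rankdir sets the layout direction passed to Graphviz ('TB' draws top to bottom).
```python
xgb.to_graphviz(xgb_model, num_trees=0, rankdir='TB')
```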
LightGBM visualization
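```python
# tree_index selects which tree to draw; extra keyword arguments such as encoding go to graphviz.Digraph
lgb.create_tree_digraph(lgb_model, tree_index=0, encoding='UTF-8')
```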
Follow the public account Python risk control model and data analysis for more knowledge and code sharing