ツール シリーズ: TensorFlow デシジョン フォレスト_(7) デシジョン フォレスト モデルの確認とデバッグ


この記事では、モデルの構造を直接調べて作成する方法を学びます。初級レベルと中級レベルで紹介された概念についてはすでによく理解していることを前提としています。

この記事では次のことを行います。

  1. ランダム フォレスト モデルをトレーニングし、プログラムでその構造にアクセスします。

  2. ランダム フォレスト モデルを手動で作成し、クラシック モデルとして使用します。

設定

# 安装 TensorFlow Decision Forests 库
!pip install tensorflow_decision_forests

# 安装 wurlitzer 库,用于显示训练日志
!pip install wurlitzer
Collecting tensorflow_decision_forests
  Using cached tensorflow_decision_forests-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.2 MB)
Requirement already satisfied: wheel in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (0.37.1)
Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.24.0rc2)
Requirement already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.3.0)
Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.16.0)
Collecting wurlitzer
  Using cached wurlitzer-3.0.3-py3-none-any.whl (7.3 kB)
Requirement already satisfied: pandas in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.5.2)
Requirement already satisfied: tensorflow~=2.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.11.0)
Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (65.6.3)
Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (0.4.0)
Requirement already satisfied: h5py>=2.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (3.7.0)
Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (14.0.6)
Requirement already satisfied: flatbuffers>=2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (22.12.6)
Requirement already satisfied: tensorboard<2.12,>=2.11 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (2.11.0)
Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (4.4.0)
Requirement already satisfied: keras<2.12,>=2.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (2.11.0)
Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (1.14.1)
Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (1.6.3)
Requirement already satisfied: tensorflow-estimator<2.12,>=2.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (2.11.0)
Requirement already satisfied: protobuf<3.20,>=3.9.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (3.19.6)
Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (3.3.0)
Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (1.51.1)
Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (22.0)
Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (2.1.1)
Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (0.28.0)
Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (0.2.0)
Requirement already satisfied: python-dateutil>=2.8.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2022.6)
Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2.2.2)
Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (3.4.1)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (1.8.1)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (0.6.1)
Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2.28.1)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (0.4.6)
Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2.15.0)
Requirement already satisfied: cachetools<6.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (5.2.0)
Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (4.9)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (0.3.0rc1)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (1.3.1)
Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (5.1.0)
Requirement already satisfied: charset-normalizer<3,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2.1.1)
Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (3.4)
Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2022.12.7)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (1.26.13)
Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2.1.1)
Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (3.11.0)
Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (0.5.0rc2)
Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (3.2.2)
Installing collected packages: wurlitzer, tensorflow_decision_forests
Successfully installed tensorflow_decision_forests-1.1.0 wurlitzer-3.0.3
Requirement already satisfied: wurlitzer in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (3.0.3)
# 导入tensorflow_decision_forests库
import tensorflow_decision_forests as tfdf

# 导入os、numpy、pandas、tensorflow、matplotlib.pyplot、math、collections库
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import math
import collections

2022-12-14 12:24:51.050867: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 12:24:51.050964: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 12:24:51.050973: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

非表示のコード セルにより、colab での出力の高さが制限されます。


# 导入所需的模块
from IPython.core.magic import register_line_magic
from IPython.display import Javascript
from IPython.display import display as ipy_display

# 定义一个魔术命令,用于设置单元格的最大高度
@register_line_magic
def set_cell_height(size):
  # 调用Javascript代码,设置单元格的最大高度
  ipy_display(
      Javascript("google.colab.output.setIframeHeight(0, true, {maxHeight: " +
                 str(size) + "})"))

単純なランダムフォレストのトレーニング

初心者向けの colabのようにランダムフォレストを訓練します

# 下载数据集
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv

# 将数据集加载到Pandas Dataframe中
dataset_df = pd.read_csv("/tmp/penguins.csv")

# 显示前三个示例
print(dataset_df.head(3))

# 将Pandas Dataframe转换为tf数据集
dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")

# 训练随机森林模型
model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)
model.fit(x=dataset_tf)
  species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  \
0  Adelie  Torgersen            39.1           18.7              181.0   
1  Adelie  Torgersen            39.5           17.4              186.0   
2  Adelie  Torgersen            40.3           18.0              195.0   

   body_mass_g     sex  year  
0       3750.0    male  2007  
1       3800.0  female  2007  
2       3250.0  female  2007  
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


Use /tmpfs/tmp/tmpvr7urazn as temporary training directory
Reading training dataset...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Training dataset read in 0:00:02.961832. Found 344 examples.
Training model...
Model trained in 0:00:00.093680
Compiling model...


[INFO 2022-12-14T12:24:58.955519768+00:00 kernel.cc:1175] Loading model from path /tmpfs/tmp/tmpvr7urazn/model/ with prefix fb8057db01324481
[INFO 2022-12-14T12:24:58.971817533+00:00 abstract_model.cc:1306] Engine "RandomForestGeneric" built
[INFO 2022-12-14T12:24:58.97187255+00:00 kernel.cc:1021] Use fast generic engine


WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7f9b54f644c0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert


WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7f9b54f644c0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert


WARNING: AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7f9b54f644c0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
Model compiled.





<keras.callbacks.History at 0x7f9b5394c6d0>

モデル コンストラクターのcompute_oob_variable_importances=Trueハイパーパラメーターに注意してください。このオプションは、トレーニング中に Out-of-bag (OOB) 変数の重要性を計算します。これは、ランダム フォレスト モデルでよく使用される置換変数の重要性です。

OOB 変数の重要性の計算は最終モデルには影響しませんが、大規模なデータセットでのトレーニングの速度が低下します。

モデルの概要を確認してください。

# 打印模型的概述信息
model.summary()
<IPython.core.display.Javascript object>


Model: "random_forest_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (7):
	bill_depth_mm
	bill_length_mm
	body_mass_g
	flipper_length_mm
	island
	sex
	year

No weights

Variable Importance: MEAN_DECREASE_IN_ACCURACY:
    1.    "bill_length_mm"  0.151163 ################
    2.            "island"  0.008721 #
    3.     "bill_depth_mm"  0.000000 
    4.       "body_mass_g"  0.000000 
    5.               "sex"  0.000000 
    6.              "year"  0.000000 
    7. "flipper_length_mm" -0.002907 

Variable Importance: MEAN_DECREASE_IN_AP_1_VS_OTHERS:
    1.    "bill_length_mm"  0.083305 ################
    2.            "island"  0.007664 #
    3. "flipper_length_mm"  0.003400 
    4.     "bill_depth_mm"  0.002741 
    5.       "body_mass_g"  0.000722 
    6.               "sex"  0.000644 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_AP_2_VS_OTHERS:
    1.    "bill_length_mm"  0.508510 ################
    2.            "island"  0.023487 
    3.     "bill_depth_mm"  0.007744 
    4. "flipper_length_mm"  0.006008 
    5.       "body_mass_g"  0.003017 
    6.               "sex"  0.001537 
    7.              "year" -0.000245 

Variable Importance: MEAN_DECREASE_IN_AP_3_VS_OTHERS:
    1.            "island"  0.002192 ################
    2.    "bill_length_mm"  0.001572 ############
    3.     "bill_depth_mm"  0.000497 #######
    4.               "sex"  0.000000 ####
    5.              "year"  0.000000 ####
    6.       "body_mass_g" -0.000053 ####
    7. "flipper_length_mm" -0.000890 

Variable Importance: MEAN_DECREASE_IN_AUC_1_VS_OTHERS:
    1.    "bill_length_mm"  0.071306 ################
    2.            "island"  0.007299 #
    3. "flipper_length_mm"  0.004506 #
    4.     "bill_depth_mm"  0.002124 
    5.       "body_mass_g"  0.000548 
    6.               "sex"  0.000480 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_AUC_2_VS_OTHERS:
    1.    "bill_length_mm"  0.108642 ################
    2.            "island"  0.014493 ##
    3.     "bill_depth_mm"  0.007406 #
    4. "flipper_length_mm"  0.005195 
    5.       "body_mass_g"  0.001012 
    6.               "sex"  0.000480 
    7.              "year" -0.000053 

Variable Importance: MEAN_DECREASE_IN_AUC_3_VS_OTHERS:
    1.            "island"  0.002126 ################
    2.    "bill_length_mm"  0.001393 ###########
    3.     "bill_depth_mm"  0.000293 #####
    4.               "sex"  0.000000 ###
    5.              "year"  0.000000 ###
    6.       "body_mass_g" -0.000037 ###
    7. "flipper_length_mm" -0.000550 

Variable Importance: MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS:
    1.    "bill_length_mm"  0.083122 ################
    2.            "island"  0.010887 ##
    3. "flipper_length_mm"  0.003425 
    4.     "bill_depth_mm"  0.002731 
    5.       "body_mass_g"  0.000719 
    6.               "sex"  0.000641 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS:
    1.    "bill_length_mm"  0.497611 ################
    2.            "island"  0.024045 
    3.     "bill_depth_mm"  0.007734 
    4. "flipper_length_mm"  0.006017 
    5.       "body_mass_g"  0.003000 
    6.               "sex"  0.001528 
    7.              "year" -0.000243 

Variable Importance: MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS:
    1.            "island"  0.002187 ################
    2.    "bill_length_mm"  0.001568 ############
    3.     "bill_depth_mm"  0.000495 #######
    4.               "sex"  0.000000 ####
    5.              "year"  0.000000 ####
    6.       "body_mass_g" -0.000053 ####
    7. "flipper_length_mm" -0.000886 

Variable Importance: MEAN_MIN_DEPTH:
    1.           "__LABEL"  3.479602 ################
    2.              "year"  3.463891 ###############
    3.               "sex"  3.430498 ###############
    4.       "body_mass_g"  2.898112 ###########
    5.            "island"  2.388925 ########
    6.     "bill_depth_mm"  2.336100 #######
    7.    "bill_length_mm"  1.282960 
    8. "flipper_length_mm"  1.270079 

Variable Importance: NUM_AS_ROOT:
    1. "flipper_length_mm" 157.000000 ################
    2.    "bill_length_mm" 76.000000 #######
    3.     "bill_depth_mm" 52.000000 #####
    4.            "island" 12.000000 
    5.       "body_mass_g"  3.000000 

Variable Importance: NUM_NODES:
    1.    "bill_length_mm" 778.000000 ################
    2.     "bill_depth_mm" 463.000000 #########
    3. "flipper_length_mm" 414.000000 ########
    4.            "island" 342.000000 ######
    5.       "body_mass_g" 338.000000 ######
    6.               "sex" 36.000000 
    7.              "year" 19.000000 

Variable Importance: SUM_SCORE:
    1.    "bill_length_mm" 36515.793787 ################
    2. "flipper_length_mm" 35120.434174 ###############
    3.            "island" 14669.408395 ######
    4.     "bill_depth_mm" 14515.446617 ######
    5.       "body_mass_g" 3485.330881 #
    6.               "sex" 354.201073 
    7.              "year" 49.737758 



Winner takes all: true
Out-of-bag evaluation: accuracy:0.976744 logloss:0.0678223
Number of trees: 300
Total number of nodes: 5080

Number of nodes by tree:
Count: 300 Average: 16.9333 StdDev: 3.10197
Min: 11 Max: 31 Ignored: 0
----------------------------------------------
[ 11, 12)  6   2.00%   2.00% #
[ 12, 13)  0   0.00%   2.00%
[ 13, 14) 46  15.33%  17.33% #####
[ 14, 15)  0   0.00%  17.33%
[ 15, 16) 70  23.33%  40.67% ########
[ 16, 17)  0   0.00%  40.67%
[ 17, 18) 84  28.00%  68.67% ##########
[ 18, 19)  0   0.00%  68.67%
[ 19, 20) 46  15.33%  84.00% #####
[ 20, 21)  0   0.00%  84.00%
[ 21, 22) 30  10.00%  94.00% ####
[ 22, 23)  0   0.00%  94.00%
[ 23, 24) 13   4.33%  98.33% ##
[ 24, 25)  0   0.00%  98.33%
[ 25, 26)  2   0.67%  99.00%
[ 26, 27)  0   0.00%  99.00%
[ 27, 28)  2   0.67%  99.67%
[ 28, 29)  0   0.00%  99.67%
[ 29, 30)  0   0.00%  99.67%
[ 30, 31]  1   0.33% 100.00%

Depth by leafs:
Count: 2690 Average: 3.53271 StdDev: 1.06789
Min: 2 Max: 7 Ignored: 0
----------------------------------------------
[ 2, 3) 545  20.26%  20.26% ######
[ 3, 4) 747  27.77%  48.03% ########
[ 4, 5) 888  33.01%  81.04% ##########
[ 5, 6) 444  16.51%  97.55% #####
[ 6, 7)  62   2.30%  99.85% #
[ 7, 7]   4   0.15% 100.00%

Number of training obs by leaf:
Count: 2690 Average: 38.3643 StdDev: 44.8651
Min: 5 Max: 155 Ignored: 0
----------------------------------------------
[   5,  12) 1474  54.80%  54.80% ##########
[  12,  20)  124   4.61%  59.41% #
[  20,  27)   48   1.78%  61.19%
[  27,  35)   74   2.75%  63.94% #
[  35,  42)   58   2.16%  66.10%
[  42,  50)   85   3.16%  69.26% #
[  50,  57)   96   3.57%  72.83% #
[  57,  65)   87   3.23%  76.06% #
[  65,  72)   49   1.82%  77.88%
[  72,  80)   23   0.86%  78.74%
[  80,  88)   30   1.12%  79.85%
[  88,  95)   23   0.86%  80.71%
[  95, 103)   42   1.56%  82.27%
[ 103, 110)   62   2.30%  84.57%
[ 110, 118)  115   4.28%  88.85% #
[ 118, 125)  115   4.28%  93.12% #
[ 125, 133)   98   3.64%  96.77% #
[ 133, 140)   49   1.82%  98.59%
[ 140, 148)   31   1.15%  99.74%
[ 148, 155]    7   0.26% 100.00%

Attribute in nodes:
	778 : bill_length_mm [NUMERICAL]
	463 : bill_depth_mm [NUMERICAL]
	414 : flipper_length_mm [NUMERICAL]
	342 : island [CATEGORICAL]
	338 : body_mass_g [NUMERICAL]
	36 : sex [CATEGORICAL]
	19 : year [NUMERICAL]

Attribute in nodes with depth <= 0:
	157 : flipper_length_mm [NUMERICAL]
	76 : bill_length_mm [NUMERICAL]
	52 : bill_depth_mm [NUMERICAL]
	12 : island [CATEGORICAL]
	3 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 1:
	250 : bill_length_mm [NUMERICAL]
	244 : flipper_length_mm [NUMERICAL]
	183 : bill_depth_mm [NUMERICAL]
	170 : island [CATEGORICAL]
	53 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 2:
	462 : bill_length_mm [NUMERICAL]
	320 : flipper_length_mm [NUMERICAL]
	310 : bill_depth_mm [NUMERICAL]
	287 : island [CATEGORICAL]
	162 : body_mass_g [NUMERICAL]
	9 : sex [CATEGORICAL]
	5 : year [NUMERICAL]

Attribute in nodes with depth <= 3:
	669 : bill_length_mm [NUMERICAL]
	410 : bill_depth_mm [NUMERICAL]
	383 : flipper_length_mm [NUMERICAL]
	328 : island [CATEGORICAL]
	286 : body_mass_g [NUMERICAL]
	32 : sex [CATEGORICAL]
	10 : year [NUMERICAL]

Attribute in nodes with depth <= 5:
	778 : bill_length_mm [NUMERICAL]
	462 : bill_depth_mm [NUMERICAL]
	413 : flipper_length_mm [NUMERICAL]
	342 : island [CATEGORICAL]
	338 : body_mass_g [NUMERICAL]
	36 : sex [CATEGORICAL]
	19 : year [NUMERICAL]

Condition type in nodes:
	2012 : HigherCondition
	378 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
	288 : HigherCondition
	12 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
	730 : HigherCondition
	170 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
	1259 : HigherCondition
	296 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
	1758 : HigherCondition
	360 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
	2010 : HigherCondition
	378 : ContainsBitmapCondition
Node format: NOT_SET

Training OOB:
	trees: 1, Out-of-bag evaluation: accuracy:0.964286 logloss:1.28727
	trees: 13, Out-of-bag evaluation: accuracy:0.959064 logloss:0.4869
	trees: 31, Out-of-bag evaluation: accuracy:0.95614 logloss:0.284603
	trees: 54, Out-of-bag evaluation: accuracy:0.973837 logloss:0.175283
	trees: 73, Out-of-bag evaluation: accuracy:0.97093 logloss:0.175816
	trees: 85, Out-of-bag evaluation: accuracy:0.973837 logloss:0.171781
	trees: 96, Out-of-bag evaluation: accuracy:0.97093 logloss:0.077417
	trees: 116, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0761788
	trees: 127, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0745239
	trees: 137, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0753508
	trees: 150, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0741464
	trees: 160, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0749481
	trees: 170, Out-of-bag evaluation: accuracy:0.979651 logloss:0.0719624
	trees: 190, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0711787
	trees: 203, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0701121
	trees: 213, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0682979
	trees: 224, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689686
	trees: 248, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0674086
	trees: 260, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068218
	trees: 270, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0680733
	trees: 280, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0685965
	trees: 290, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0683421
	trees: 300, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0678223

変数の重要度には複数の名前があることに注意してくださいMEAN_DECREASE_IN_*

モデルの描画

次にモデルを描きます。

Random Forest は巨大なモデルです (モデルには 300 のツリーと約 5,000 のノードがあります。上記の概要を参照してください)。したがって、最初のツリーのみが描画され、ノードは深さ 3 に制限されます。

# 使用model_plotter模块中的plot_model_in_colab函数来绘制模型
# 参数model表示要绘制的模型
# 参数tree_idx表示要绘制的树的索引,这里设置为0表示绘制第一棵树
# 参数max_depth表示要绘制的树的最大深度,这里设置为3表示绘制到第三层
tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)

/**

  • TF-DF によって生成された決定木のプロット。
  • ツリーはノード オブジェクトの再帰的な構造です。
  • ノードには、次のコンポーネントが 1 つ以上含まれています。
    • 値: ノードの出力を表します。ノードがリーフではない場合、
  •  the value is only present for analysis i.e. it is not used for
    
  •  predictions.
    
    • 条件 : 非リーフ ノードの場合、条件 (分割とも呼ばれます)
  •  defines a binary test to branch to the positive or negative child.
    
    • 説明: 一般に、ラベル間の関係を示すプロット
  •  and the condition to give insights about the effect of the condition.
    
    • 2 つの子 : 非リーフ ノードの場合、子ノード。最初
  •  children (i.e. "node.children[0]") is the negative children (drawn in
    
  •  red). The second children is the positive one (drawn in green).
    

*/

/**

  • 単一の決定木を DOM 要素にプロットします。
  • @param {!options} options 設定のディクショナリ。
  • @param {!tree} raw_tree 再帰的なツリー構造。
  • @param {string} Canvas_id 出力 dom 要素の ID。
    */
    function display_tree(options, raw_tree, Canvas_id) { console.log(options);

// ノードの配置を決定します。
consttree_struct = d3.tree().nodeSize(
[options.node_y_offset, options.node_x_offset])(d3.hierarchy(raw_tree));

// ノード配置の境界。
x_min = 無限大とします。
x_max = -x_min とします。
y_min = 無限大とします。
y_max = -x_min とします。

Tree_struct.each(d => { if (dx > x_max) x_max = dx; if (dx < x_min) x_min = dx; if (dy > y_max) y_max = dy; if (dy < y_min) y_min = dy; }) ;




// プロットのサイズ。
const width = y_max - y_min + options.node_x_size + options.margin * 2;
const height = x_max - x_min + options.node_y_size + options.margin * 2 +
options.node_y_offset - options.node_y_size;

const プロット = d3.select(canvas_id);

// ツールチップ
options.tooltip = Lot.append('div')
.attr('width', 100)
.attr('height', 100)
.style('padding', '4px')
.style('background ', '#fff')
.style('box-shadow', '4px 4px 0px rgba(0,0,0,0.1)')
.style('border', '1px ソリッドブラック')
.style('font -family', 'sans-serif')
.style('font-size', options.font_size)
.style('position', 'absolute')
.style('z-index', '10')
.attr( 'pointer-events', 'none')
.style('display', 'none');

// キャンバスを作成します
const svg = Lot.append('svg').attr('width', width).attr('height', height);
const chart =
svg.style('overflow', 'visible')
.append('g')
.attr('font-family', 'sans-serif')
.attr('font-size', options.font_size)
.attr(
'変換',
() => translate(${options.margin},${ - x_min + options.node_y_offset / 2 + options.margin}));

// バウンディングボックスをプロットします。
if (options.show_plot_bounding_box) { svg.append('rect') .attr('width', width) .attr('height', height) .attr('fill', 'none') .attr('ストローク-幅', 1.0) .attr('ストローク', '黒'); }






// エッジを描画します。
display_edges(オプション、グラフ、ツリー構造);

// ノードを描画します。
display_nodes(オプション、グラフ、ツリー構造);
}

/**

  • ツリーのノードを描画します。
  • @param {!options} options 設定のディクショナリ。
  • @param {!graph} グラフを含むグラフ D3 検索ハンドル。
  • @param {!tree_struct} Tree_struct ツリーの構造 (ノードの配置、
  • data, etc.).
    

*/
function display_nodes(options,graph,tree_struct) { const nodes =graph.append('g') .selectAll('g') .data(tree_struct.descendants()) .join('g') .attr('変換', d => );




translate(${d.y},${d.x})

nodes.append('rect')
.attr('x', 0.5)
.attr('y', 0.5)
.attr('width', options.node_x_size)
.attr('height', options.node_y_size)
.attr ('ストローク', 'ライトグレー')
.attr('ストローク幅', 1)
.attr('塗り', '白')
.attr('y', -options.node_y_size / 2);

// 子のない条件ノードの右側にある括弧。
non_leaf_node_without_children =
nodes.filter(node => node.data.condition != null && node.children == null)
.append('g')
.attr('transform', translate(${options.node_x_size},0));

non_leaf_node_without_children.append('path')
.attr('d', 'M0,0 C 10,0 0,10 10,10')
.attr('fill', 'none')
.attr('ストローク幅' 、1.0)
.attr('ストローク', '#F00');

non_leaf_node_without_children.append('path')
.attr('d', 'M0,0 C 10,0 0,-10 10,-10')
.attr('fill', 'none')
.attr('ストローク-幅', 1.0)
.attr('ストローク', '#0F0');

const node_content = nodes.append('g').attr(
'transform',
translate(0,${options.node_padding - options.node_y_size / 2}));

node_content.append(node => create_node_element(オプション, ノード));
}

/**

  • 単一ノードの D3 コンテンツを作成します。
  • @param {!options} options 設定のディクショナリ。
  • @param {!node} ノード 描画するノード。
  • @return {!d3} D3 コンテンツ。
    */
    function create_node_element(options, node) { // 出力アキュムレータ。let Output = { // 描画するコンテンツ。content: d3.create('svg:g'), // 次に描画する要素までの垂直オフセット。垂直オフセット: 0 };






// 条件。
if (node.data.condition != null) { display_condition(オプション、node.data.condition、出力); }

// 値。
if (node.data.value != null) { display_value(オプション、node.data.value、出力); }

// 説明。
if (node.data.explanation != null) { display_explanation(オプション、node.data.explanation、出力); }

出力.content.node()を返します;
}

/**

  • ノード内に 1 行のテキストを追加します。
  • @param {!options} options 設定のディクショナリ。
  • @param {string} text 表示するテキスト。
  • @param {!output} 出力 表示アキュムレータを出力します。
    */
    function display_node_text(options, text,output) { output.content.append('text') .attr('x', options.node_padding) .attr('y', Output.vertical_offset) .attr('alignment-ベースライン', 'ハンギング') .text(text); 出力.垂直オフセット += 10; }






/**

  • ツールヒントを含む 1 行のテキストをノード内に追加します。
  • @param {!options} options 設定のディクショナリ。
  • @param {string} text 表示するテキスト。
  • @param {string} ツールチップ ツールチップのテキスト。
  • @param {!output} 出力 表示アキュムレータを出力します。
    */
    function display_node_text_with_tooltip(options, text,tooltip,output) { const item = Output.content.append('text') .attr('x', options.node_padding) .attr('alignment-baseline', 'hanging' ) .text(テキスト);



add_tooltip(オプション, 項目, () => ツールチップ);
出力.垂直オフセット += 10;
}

/**

  • dom 要素にツールチップを追加します。
  • @param {!options} options 設定のディクショナリ。
  • @param {!dom} ツールチップを装備する対象の Dom 要素。
  • @param {!func} get_content ツールヒントの HTML コンテンツを生成します。
    */
    function add_tooltip(options, target, get_content) { function show(d) { options.tooltip.style('display', 'block'); options.tooltip.html(get_content()); }



関数 Hide(d) { options.tooltip.style('display', 'none'); }

関数 move(d) { options.tooltip.style('display', 'block'); options.tooltip.style('left', (d.pageX + 5) + 'px'); options.tooltip.style('top', d.pageY + 'px'); }



target.on('マウスオーバー', 表示);
target.on('マウスアウト', 非表示);
target.on('mousemove', 移動);
}

/**

  • ノード内に条件を追加します。
  • @param {!options} options 設定のディクショナリ。
  • @param {!condition} 条件 表示する条件。
  • @param {!output} 出力 表示アキュムレータを出力します。
    */
    function display_condition(オプション, 条件, 出力) { threshold_format = d3.format('r');

if (condition.type === 'IS_MISSING') { display_node_text(オプション, , 出力); 戻る; }
${condition.attribute} is missing

if (condition.type === 'IS_TRUE') { display_node_text(オプション, , 出力); 戻る; }
${condition.attribute} is true

if (condition.type === 'NUMERICAL_IS_HIGHER_THAN') { format = d3.format('r'); display_node_text(オプション, ,出力); 戻る; }



${condition.attribute} >= ${threshold_format(condition.threshold)}


if (condition.type === 'CATEGORICAL_IS_IN') { display_node_text_with_tooltip(オプション, , , 出力); 戻る; }

${condition.attribute} in [...]
${condition.attribute} in [${condition.mask}]

if (condition.type === 'CATEGORICAL_SET_CONTAINS') { display_node_text_with_tooltip(オプション, , , 出力); 戻る; }

${condition.attribute} intersect [...]
${condition.attribute} intersect [${condition.mask}]

if (condition.type === 'NUMERICAL_SPARSE_OBLIQUE') { display_node_text_with_tooltip(オプション, , ,出力); 戻る; }

Sparse oblique split...
[${condition.attributes}]*[${condition.weights}]>=${ threshold_format(condition.threshold)}


display_node_text(
オプション, Non supported condition ${condition.type}, 出力);
}

/**

  • ノード内に値を追加します。

  • @param {!options} options 設定のディクショナリ。

  • @param {!value} value 表示する値。

  • @param {!output} 出力 表示アキュムレータを出力します。
    */
    function display_value(オプション, 値, 出力) { if (value.type === 'PROBABILITY') { const left_margin = 0; const right_margin = 50; const Lot_width = options.node_x_size - options.node_padding * 2 - left_margin - right_margin;




    let cusum = Array.from(d3.cumsum(value.distribution));
    cusum.unshift(0);
    const distribution_plot = Output.content.append('g').attr(
    'transform', translate(0,${output.vertical_offset + 0.5}));

    distribution_plot.selectAll('rect')
    .data(value.distribution)
    .join('rect')
    .attr('height', 10)
    .attr(
    'x',
    (d, i) =>
    (cusum[i] * プロット幅 + 左マージン + options.node_padding))
    .attr('幅', (d, i) => d * プロット幅)
    .style('fill', (d, i) => d3.schemeSet1[i]);

    const num_examples =
    Output.content.append('g')
    .attr('transform', translate(0,${output.vertical_offset}))
    .append('text')
    .attr('x', options.node_x_size - options.node_padding)
    .attr('alignment-baseline ', 'hanging')
    .attr('text-anchor', 'end')
    .text( (${value.num_examples}));

    const distribution_details = d3.create('ul');
    distribution_details.selectAll('li')
    .data(value.distribution)
    .join('li')
    .append('span')
    .text(
    (d, i) =>
    'クラス ' + i + ': ' + d3 .format('.3%')(value.distribution[i]));

    add_tooltip(オプション, distribution_plot, () => distribution_details.html());
    add_tooltip(options, num_examples, () => 'サンプルの数');

    出力.垂直オフセット += 10;
    戻る;
    }

if (value.type === 'REGRESSION') { display_node_text( options, 'value: ' + d3.format('r')(value.value) + + d3.format('.6')(value.num_examples ) + 出力); 戻る; }


(
)


display_node_text(オプション, Non supported value ${value.type}, 出力);
}

/**

  • ノード内に説明を追加します。
  • @param {!options} options 設定のディクショナリ。
  • @param {!explanation} description 表示する説明。
  • @param {!output} 出力 表示アキュムレータを出力します。
    */
    function display_explanation(options, description, Output) { // 説明の前のマージン。出力.垂直オフセット += 10;

display_node_text(
オプション, Non supported explanation ${explanation.type}, 出力);
}

/**

  • 木の端を描きます。
  • @param {!options} options 設定のディクショナリ。
  • @param {!graph} グラフを含むグラフ D3 検索ハンドル。
  • @param {!tree_struct} Tree_struct ツリーの構造 (ノードの配置、
  • data, etc.).
    

*/
function display_edges(options,graph,tree_struct) { // 親ノードと子ノードの間にベジェでエッジを描画します。functiondraw_single_edge(d) { return 'M' + (d.source.y + options.node_x_size) + ',' + d.source.x + ' C' + (d.source.y + options.node_x_size + options. edge_rounding) + ',' + d.source.x + ' ' + (d.target.y - options.edge_rounding) + ',' + d.target.x + ' ' + d.target.y + ',' + d.target.x; }






graph.append('g')
.attr('fill', 'none')
.attr('ストローク幅', 1.2)
.selectAll('path')
.data(tree_struct.links())
.join(' path')
.attr('d',draw_single_edge)
.attr(
'ストローク', d => (d.target === d.source.children[0]) ? '#0F0' : '#F00');
}

display_tree({“margin”: 10, “node_x_size”: 160, “node_y_size”: 28, “node_x_offset”: 180, “node_y_offset”: 33, “font_size”: 10, “edge_rounding”: 20, “node_padding”: 2 , “show_plot_bounding_box”: false}, {“値”: {“タイプ”: “確率”, “分布”: [0.47093023255813954, 0.19476744186046513, 0.33430232558139533], “num_examples”: 344.0}, “条件” : {「タイプ」: “NUMERICAL_IS_HIGHER_THAN”, “属性”: “bill_length_mm”, “しきい値”: 43.25}, “children”: [{“value”: {“type”: “PROBABILITY”, “distribution”: [0.005847953216374269, 0.3567251461988304, 0.6374 269005847953]、 "num_examples": 171.0}、"条件": {"タイプ": "CATEGORICAL_IS_IN"、"属性": "アイランド"、"マスク": ["ビスコー"]}、"子": [{"値": { 「タイプ」:「確率」、「配布」:[0.0090909090909090909、0.0、0.99090909090909091]、「num_examples」:110.0}、「条件」:{「タイプ」:「numerical_is_higher_higher_than」 ”: 17.225584030151367}, “子”: [{“値”: {“タイプ”: “確率”, “分布”: [0.16666666666666666, 0.0, 0.8333333333333334], “num_examples”: 6.0}}, {“値”: { "タイプ": "確率", "分布": [0.0, 0.0, 1.0], "num_examples": 104.0}}]}, {"値": {"タイプ": "確率", "分布": [0.0 、1.0、0.0]、“num_examples”: 61.0}}]}, {“値”: {“タイプ”: “確率”、“分布”: [0.930635838150289, 0.03468208092485549、0.03468208092485549]、“num_examples”: 173.0}、「条件": {"タイプ": "NUMERICAL_IS_HIGHER_THAN", "属性": "bill_ Depth_mm", "しきい値": 15.100000381469727}, "子": [{"値": {"タイプ": "確率", "分布": [0.9640718562874252, 0.03592814371257485, 0.0]、“num_examples”: 167.0}、“条件”: {“タイプ”: “NUMERICAL_IS_HIGHER_THAN”、“属性”: “flipper_length_mm”、“しきい値”: 187.5}、“子”: [{“値": {"タイプ": "確率", "分布": [1.0, 0.0, 0.0], "num_examples": 104.0}}, {"値": {"タイプ": "確率", "分布": [0.9047619047619048, 0.09523809523809523, 0.0]、“num_examples”: 63.0}、“条件”: {“タイプ”: “NUMERICAL_IS_HIGHER_THAN”、“属性”: “bill_length_mm”、“しきい値”: 42.30000305175 781}}]}, {“値” : {“タイプ”: “確率”, “分布”: [0.0, 0.0, 1.0], “num_examples”: 6.0}}]}]}, “#tree_plot_05707b35c4f748738efd3da21ab9197f”)

モデル構造を確認する

make_inspector()モデルの構造とメタデータは、作成されたインスペクターを通じて取得できます

**注意:** 学習アルゴリズムとハイパーパラメータに応じて、インスペクタはさまざまな特殊なプロパティを公開します。たとえば、winner_take_allフィールドはランダム フォレスト モデルに固有です。

# 创建一个模型检查器对象,用于检查模型的性能和质量
inspector = model.make_inspector()

このモデルの場合、使用可能なインスペクター フィールドは次のとおりです。

# 使用列表推导式,遍历inspector模块中的所有属性
# 过滤掉以"_"开头的属性
fields = [field for field in dir(inspector) if not field.startswith("_")]
['MODEL_NAME',
 'dataspec',
 'evaluation',
 'export_to_tensorboard',
 'extract_all_trees',
 'extract_tree',
 'features',
 'header',
 'iterate_on_nodes',
 'label',
 'label_classes',
 'metadata',
 'model_type',
 'num_trees',
 'objective',
 'specialized_header',
 'task',
 'training_logs',
 'tuning_logs',
 'variable_importances',
 'winner_take_all_inference']

忘れずにAPI リファレンスを確認するか、?組み込みドキュメントの表示を使用してください。

?inspector.model_type

一部のモデルのメタデータ:

# 打印模型类型
print("Model type:", inspector.model_type())

# 打印模型中树的数量
print("Number of trees:", inspector.num_trees())

# 打印模型的目标函数
print("Objective:", inspector.objective())

# 打印模型的输入特征
print("Input features:", inspector.features())
Model type: RANDOM_FOREST
Number of trees: 300
Objective: Classification(label=__LABEL, class=None, num_classes=3)
Input features: ["bill_depth_mm" (1; #0), "bill_length_mm" (1; #1), "body_mass_g" (1; #2), "flipper_length_mm" (1; #3), "island" (4; #4), "sex" (4; #5), "year" (1; #6)]

evaluate()トレーニング中に計算されたモデルの評価です。この評価に使用されるデータセットはアルゴリズムによって異なります。たとえば、検証データセットや袋外データセットなどです。

**注意:** トレーニング中に計算されますが、evaluate()トレーニング データセットで評価されることはありません。

# 创建一个名为inspector的对象
inspector = Inspector()
# 调用inspector对象的evaluation()方法
inspector.evaluation()
Evaluation(num_examples=344, accuracy=0.9767441860465116, loss=0.06782230959804512, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)

変数の重要性は次のとおりです。

変数の重要度は次のとおりです。

# 打印可用的变量重要性
print(f"Available variable importances:")

# 遍历变量重要性字典的键,并打印出来
for importance in inspector.variable_importances().keys():
    print("\t", importance)
Available variable importances:
	 MEAN_DECREASE_IN_AP_1_VS_OTHERS
	 MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS
	 SUM_SCORE
	 MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS
	 MEAN_DECREASE_IN_ACCURACY
	 MEAN_DECREASE_IN_AUC_1_VS_OTHERS
	 MEAN_DECREASE_IN_AP_3_VS_OTHERS
	 NUM_AS_ROOT
	 MEAN_DECREASE_IN_AP_2_VS_OTHERS
	 MEAN_DECREASE_IN_AUC_2_VS_OTHERS
	 MEAN_MIN_DEPTH
	 MEAN_DECREASE_IN_AUC_3_VS_OTHERS
	 NUM_NODES
	 MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS

変数の重要度が異なれば、セマンティクスも異なります。たとえば、AUC が平均的に低下した特徴は0.05、トレーニング データセットから特徴を削除すると AUC が 5% 減少/低下することを意味します。

# 获取类别1与其他类别之间的AUC的平均减少量
mean_decrease_in_auc_1_vs_others = inspector.variable_importances()["MEAN_DECREASE_IN_AUC_1_VS_OTHERS"]
[("bill_length_mm" (1; #1), 0.0713061951754389),
 ("island" (4; #4), 0.007298519736842035),
 ("flipper_length_mm" (1; #3), 0.004505893640351366),
 ("bill_depth_mm" (1; #0), 0.0021244517543865804),
 ("body_mass_g" (1; #2), 0.0005482456140351033),
 ("sex" (4; #5), 0.00047971491228060437),
 ("year" (1; #6), 0.0)]

Matplotlib を使用してインスペクターで変数の重要性をプロットする

import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))  # 创建一个大小为12x4的图形

# 平均AUC下降值(class 1相对于其他类别)
variable_importance_metric = "MEAN_DECREASE_IN_AUC_1_VS_OTHERS"
variable_importances = inspector.variable_importances()[variable_importance_metric]

# 提取特征名称和重要性值
#
# `variable_importances` 是一个包含<特征, 重要性>元组的列表
feature_names = [vi[0].name for vi in variable_importances]  # 提取特征名称
feature_importances = [vi[1] for vi in variable_importances]  # 提取重要性值
# 特征按重要性值降序排列
feature_ranks = range(len(feature_names))

bar = plt.barh(feature_ranks, feature_importances, label=[str(x) for x in feature_ranks])  # 创建水平条形图
plt.yticks(feature_ranks, feature_names)  # 设置y轴刻度为特征名称
plt.gca().invert_yaxis()  # 反转y轴刻度顺序,使重要性高的特征在上方

# TODO: 当可用时,替换为 "plt.bar_label()"
# 使用值标记每个条形图
for importance, patch in zip(feature_importances, bar.patches):
  plt.text(patch.get_x() + patch.get_width(), patch.get_y(), f"{
      
      importance:.4f}", va="top")

plt.xlabel(variable_importance_metric)  # 设置x轴标签为重要性度量
plt.title("Mean decrease in AUC of the class 1 vs the others")  # 设置图形标题
plt.tight_layout()  # 调整图形布局,以防止标签重叠
plt.show()  # 显示图形

最後に、実際のツリー構造にアクセスします。

# 从inspector对象中提取树的信息
# 参数tree_idx表示要提取的树的索引,这里为0表示提取第一棵树的信息
inspector.extract_tree(tree_idx=0)
Tree(root=NonLeafNode(condition=(bill_length_mm >= 43.25; miss=True, score=0.5482327342033386), pos_child=NonLeafNode(condition=(island in ['Biscoe']; miss=True, score=0.6515106558799744), pos_child=NonLeafNode(condition=(bill_depth_mm >= 17.225584030151367; miss=False, score=0.027205035090446472), pos_child=LeafNode(value=ProbabilityValue([0.16666666666666666, 0.0, 0.8333333333333334],n=6.0), idx=7), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=104.0), idx=6), value=ProbabilityValue([0.00909090909090909, 0.0, 0.990909090909091],n=110.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=61.0), idx=5), value=ProbabilityValue([0.005847953216374269, 0.3567251461988304, 0.6374269005847953],n=171.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 15.100000381469727; miss=True, score=0.150658518075943), pos_child=NonLeafNode(condition=(flipper_length_mm >= 187.5; miss=True, score=0.036139510571956635), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=104.0), idx=4), neg_child=NonLeafNode(condition=(bill_length_mm >= 42.30000305175781; miss=True, score=0.23430533707141876), pos_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=5.0), idx=3), neg_child=NonLeafNode(condition=(bill_length_mm >= 40.55000305175781; miss=True, score=0.043961383402347565), pos_child=LeafNode(value=ProbabilityValue([0.8, 0.2, 0.0],n=5.0), idx=2), neg_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=53.0), idx=1), value=ProbabilityValue([0.9827586206896551, 0.017241379310344827, 0.0],n=58.0)), value=ProbabilityValue([0.9047619047619048, 0.09523809523809523, 0.0],n=63.0)), value=ProbabilityValue([0.9640718562874252, 0.03592814371257485, 0.0],n=167.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=6.0), idx=0), value=ProbabilityValue([0.930635838150289, 0.03468208092485549, 0.03468208092485549],n=173.0)), value=ProbabilityValue([0.47093023255813954, 0.19476744186046513, 0.33430232558139533],n=344.0)), label_classes=None)

木の抽出は効率的ではありません。iterate_on_nodes()速度が重要な場合は、モデル チェックのメソッドを使用できます。このメソッドは、モデルのすべてのノードに対する深さ優先の事前順序走査反復子です。

注:extract_tree()これは をiterate_on_nodes()使用して実装されます。

次の例では、各機能が使用された回数をカウントします (構造変数の重要性の指標)。

# 创建一个默认字典number_of_use,用于记录每个特征在其条件中被使用的次数
number_of_use = collections.defaultdict(lambda: 0)

# 对所有节点进行深度优先的前序遍历
for node_iter in inspector.iterate_on_nodes():

  # 如果节点不是叶节点,则跳过
  if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):
    continue

  # 遍历节点条件中使用的所有特征
  # 默认情况下,模型是"oblique"的,即每个节点测试一个特征
  for feature in node_iter.node.condition.features():
    # 特征在使用次数上加1
    number_of_use[feature] += 1

# 打印每个特征的条件节点数
print("Number of condition nodes per features:")
for feature, count in number_of_use.items():
  print("\t", feature.name, ":", count)
Number of condition nodes per features:
	 bill_length_mm : 778
	 bill_depth_mm : 463
	 flipper_length_mm : 414
	 island : 342
	 body_mass_g : 338
	 year : 19
	 sex : 36

モデルを手動で作成する

このセクションでは、小さなランダム フォレスト モデルを手動で作成します。さらに単純にするために、モデルには単純なツリーのみが含まれています。

3个标签类别:红色、蓝色和绿色。
2个特征:f1(数值型)和f2(字符串分类型)

f1>=1.5
    ├─(正)─ f2在["猫","狗"]中
    │         ├─(正)─ 值:[0.8, 0.1, 0.1]
    │         └─(负)─ 值:[0.1, 0.8, 0.1]
    └─(负)─ 值:[0.1, 0.1, 0.8]
# 创建模型构建器
builder = tfdf.builder.RandomForestBuilder(
    path="/tmp/manual_model",  # 指定模型保存的路径
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="color",  # 指定目标变量为"color"
        classes=["red", "blue", "green"]))  # 指定目标变量的类别为["red", "blue", "green"]

各ツリーは 1 つずつ追加されます。

注:ツリー オブジェクト ( ) はtfdf.py_tree.tree.Tree、前のセクションでextract_tree()返されたツリー オブジェクトと同じです。

# 导入所需的模块和类
Tree = tfdf.py_tree.tree.Tree  # 树结构
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec  # 列规范
ColumnType = tfdf.py_tree.dataspec.ColumnType  # 列类型
NonLeafNode = tfdf.py_tree.node.NonLeafNode  # 非叶节点
LeafNode = tfdf.py_tree.node.LeafNode  # 叶节点
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition  # 数值大于条件
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition  # 类别在条件
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue  # 概率值

# 创建树结构并添加到builder中
builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="f1", type=ColumnType.NUMERICAL),  # 数值特征"f1"
                threshold=1.5,  # 阈值为1.5
                missing_evaluation=False),  # 不考虑缺失值
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="f2",type=ColumnType.CATEGORICAL),  # 类别特征"f2"
                    mask=["cat", "dog"],  # 类别为"cat"或"dog"
                    missing_evaluation=False),  # 不考虑缺失值
                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.1, 0.1], num_examples=10)),  # 正向子节点为叶节点,概率值为[0.8, 0.1, 0.1],样本数为10
                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.8, 0.1], num_examples=20))),  # 负向子节点为叶节点,概率值为[0.1, 0.8, 0.1],样本数为20
            neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.1, 0.8], num_examples=30)))))  # 负向子节点为叶节点,概率值为[0.1, 0.1, 0.8],样本数为30

エンドツリーの書き込み

# 关闭builder对象
builder.close()
[INFO 2022-12-14T12:25:00.790486355+00:00 kernel.cc:1175] Loading model from path /tmp/manual_model/tmp/ with prefix e09a067144bc479b
[INFO 2022-12-14T12:25:00.790802259+00:00 decision_forest.cc:640] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO 2022-12-14T12:25:00.790878962+00:00 kernel.cc:1021] Use fast generic engine
WARNING:absl:Found untraced functions such as call_get_leaves, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading.


INFO:tensorflow:Assets written to: /tmp/manual_model/assets


INFO:tensorflow:Assets written to: /tmp/manual_model/assets

これで、モデルを通常の keras モデルとして開き、予測を行うことができます。

# 加载预训练模型
manual_model = tf.keras.models.load_model("/tmp/manual_model")
[INFO 2022-12-14T12:25:01.436506097+00:00 kernel.cc:1175] Loading model from path /tmp/manual_model/assets/ with prefix e09a067144bc479b
[INFO 2022-12-14T12:25:01.436871761+00:00 decision_forest.cc:640] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO 2022-12-14T12:25:01.436909696+00:00 kernel.cc:1021] Use fast generic engine
# 创建一个tf.data.Dataset对象,从给定的张量中切片得到数据集
# 数据集包含两个特征"f1"和"f2",分别是浮点数和字符串类型
# 数据集中的每个样本是一个字典,包含"f1"和"f2"两个键
# 样本数据为:
#   "f1": [1.0, 2.0, 3.0]
#   "f2": ["cat", "cat", "bird"]
# 使用batch(2)方法将数据集划分为大小为2的批次
examples = tf.data.Dataset.from_tensor_slices({
    
    
    "f1": [1.0, 2.0, 3.0],
    "f2": ["cat", "cat", "bird"]
}).batch(2)

# 使用manual_model对examples进行预测
predictions = manual_model.predict(examples)

# 打印预测结果
print("predictions:\n", predictions)
1/2 [==============>...............] - ETA: 0s
2/2 [==============================] - 0s 2ms/step
predictions:
 [[0.1 0.1 0.8]
 [0.8 0.1 0.1]
 [0.1 0.8 0.1]]

アクセス構造:

注:モデルはシリアル化および逆シリアル化されるため、代替の同等の形式を使用する必要があります。

# 代码注释

# 获取yggdrasil模型路径
yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode("utf-8")
print("yggdrasil_model_path:",yggdrasil_model_path)

# 创建一个模型检查器,用于检查模型的输入特征
inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)
print("Input features:", inspector.features())
yggdrasil_model_path: /tmp/manual_model/assets/
Input features: ["f1" (1; #1), "f2" (4; #2)]

もちろん、この構築されたモデルを手動で描画することもできます。

# 导入tfdf库中的plot_model_in_colab函数
import tensorflow_decision_forests as tfdf

# 使用plot_model_in_colab函数绘制manual_model模型的结构图
tfdf.model_plotter.plot_model_in_colab(manual_model)

/**

  • TF-DF によって生成された決定木のプロット。
  • ツリーはノード オブジェクトの再帰的な構造です。
  • ノードには、次のコンポーネントが 1 つ以上含まれています。
    • 値: ノードの出力を表します。ノードがリーフではない場合、
  •  the value is only present for analysis i.e. it is not used for
    
  •  predictions.
    
    • 条件 : 非リーフ ノードの場合、条件 (分割とも呼ばれます)
  •  defines a binary test to branch to the positive or negative child.
    
    • 説明: 一般に、ラベル間の関係を示すプロット
  •  and the condition to give insights about the effect of the condition.
    
    • 2 つの子 : 非リーフ ノードの場合、子ノード。最初
  •  children (i.e. "node.children[0]") is the negative children (drawn in
    
  •  red). The second children is the positive one (drawn in green).
    

*/

/**

  • 単一の決定木を DOM 要素にプロットします。
  • @param {!options} options 設定のディクショナリ。
  • @param {!tree} raw_tree 再帰的なツリー構造。
  • @param {string} Canvas_id 出力 dom 要素の ID。
    */
    function display_tree(options, raw_tree, Canvas_id) { console.log(options);

// ノードの配置を決定します。
consttree_struct = d3.tree().nodeSize(
[options.node_y_offset, options.node_x_offset])(d3.hierarchy(raw_tree));

// ノード配置の境界。
x_min = 無限大とします。
x_max = -x_min とします。
y_min = 無限大とします。
y_max = -x_min とします。

Tree_struct.each(d => { if (dx > x_max) x_max = dx; if (dx < x_min) x_min = dx; if (dy > y_max) y_max = dy; if (dy < y_min) y_min = dy; }) ;




// プロットのサイズ。
const width = y_max - y_min + options.node_x_size + options.margin * 2;
const height = x_max - x_min + options.node_y_size + options.margin * 2 +
options.node_y_offset - options.node_y_size;

const プロット = d3.select(canvas_id);

// ツールチップ
options.tooltip = Lot.append('div')
.attr('width', 100)
.attr('height', 100)
.style('padding', '4px')
.style('background ', '#fff')
.style('box-shadow', '4px 4px 0px rgba(0,0,0,0.1)')
.style('border', '1px ソリッドブラック')
.style('font -family', 'sans-serif')
.style('font-size', options.font_size)
.style('position', 'absolute')
.style('z-index', '10')
.attr( 'pointer-events', 'none')
.style('display', 'none');

// キャンバスを作成します
const svg = Lot.append('svg').attr('width', width).attr('height', height);
const chart =
svg.style('overflow', 'visible')
.append('g')
.attr('font-family', 'sans-serif')
.attr('font-size', options.font_size)
.attr(
'変換',
() => translate(${options.margin},${ - x_min + options.node_y_offset / 2 + options.margin}));

// バウンディングボックスをプロットします。
if (options.show_plot_bounding_box) { svg.append('rect') .attr('width', width) .attr('height', height) .attr('fill', 'none') .attr('ストローク-幅', 1.0) .attr('ストローク', '黒'); }






// エッジを描画します。
display_edges(オプション、グラフ、ツリー構造);

// ノードを描画します。
display_nodes(オプション、グラフ、ツリー構造);
}

/**

  • ツリーのノードを描画します。
  • @param {!options} options 設定のディクショナリ。
  • @param {!graph} グラフを含むグラフ D3 検索ハンドル。
  • @param {!tree_struct} Tree_struct ツリーの構造 (ノードの配置、
  • data, etc.).
    

*/
function display_nodes(options,graph,tree_struct) { const nodes =graph.append('g') .selectAll('g') .data(tree_struct.descendants()) .join('g') .attr('変換', d => );




translate(${d.y},${d.x})

nodes.append('rect')
.attr('x', 0.5)
.attr('y', 0.5)
.attr('width', options.node_x_size)
.attr('height', options.node_y_size)
.attr ('ストローク', 'ライトグレー')
.attr('ストローク幅', 1)
.attr('塗り', '白')
.attr('y', -options.node_y_size / 2);

// 子のない条件ノードの右側にある括弧。
non_leaf_node_without_children =
nodes.filter(node => node.data.condition != null && node.children == null)
.append('g')
.attr('transform', translate(${options.node_x_size},0));

non_leaf_node_without_children.append('path')
.attr('d', 'M0,0 C 10,0 0,10 10,10')
.attr('fill', 'none')
.attr('ストローク幅' 、1.0)
.attr('ストローク', '#F00');

non_leaf_node_without_children.append('path')
.attr('d', 'M0,0 C 10,0 0,-10 10,-10')
.attr('fill', 'none')
.attr('ストローク-幅', 1.0)
.attr('ストローク', '#0F0');

const node_content = nodes.append('g').attr(
'transform',
translate(0,${options.node_padding - options.node_y_size / 2}));

node_content.append(node => create_node_element(オプション, ノード));
}

/**

  • 単一ノードの D3 コンテンツを作成します。
  • @param {!options} options 設定のディクショナリ。
  • @param {!node} ノード 描画するノード。
  • @return {!d3} D3 コンテンツ。
    */
    function create_node_element(options, node) { // 出力アキュムレータ。let Output = { // 描画するコンテンツ。content: d3.create('svg:g'), // 次に描画する要素までの垂直オフセット。垂直オフセット: 0 };






// 条件。
if (node.data.condition != null) { display_condition(オプション、node.data.condition、出力); }

// 値。
if (node.data.value != null) { display_value(オプション、node.data.value、出力); }

// 説明。
if (node.data.explanation != null) { display_explanation(オプション、node.data.explanation、出力); }

出力.content.node()を返します;
}

/**

  • ノード内に 1 行のテキストを追加します。
  • @param {!options} options 設定のディクショナリ。
  • @param {string} text 表示するテキスト。
  • @param {!output} 出力 表示アキュムレータを出力します。
    */
    function display_node_text(options, text,output) { output.content.append('text') .attr('x', options.node_padding) .attr('y', Output.vertical_offset) .attr('alignment-ベースライン', 'ハンギング') .text(text); 出力.垂直オフセット += 10; }






/**

  • ツールヒントを含む 1 行のテキストをノード内に追加します。
  • @param {!options} options 設定のディクショナリ。
  • @param {string} text 表示するテキスト。
  • @param {string} ツールチップ ツールチップのテキスト。
  • @param {!output} 出力 表示アキュムレータを出力します。
    */
    function display_node_text_with_tooltip(options, text,tooltip,output) { const item = Output.content.append('text') .attr('x', options.node_padding) .attr('alignment-baseline', 'hanging' ) .text(テキスト);



add_tooltip(オプション, 項目, () => ツールチップ);
出力.垂直オフセット += 10;
}

/**

  • dom 要素にツールチップを追加します。
  • @param {!options} options 設定のディクショナリ。
  • @param {!dom} ツールチップを装備する対象の Dom 要素。
  • @param {!func} get_content ツールヒントの HTML コンテンツを生成します。
    */
    function add_tooltip(options, target, get_content) { function show(d) { options.tooltip.style('display', 'block'); options.tooltip.html(get_content()); }



関数 Hide(d) { options.tooltip.style('display', 'none'); }

関数 move(d) { options.tooltip.style('display', 'block'); options.tooltip.style('left', (d.pageX + 5) + 'px'); options.tooltip.style('top', d.pageY + 'px'); }



target.on('マウスオーバー', 表示);
target.on('マウスアウト', 非表示);
target.on('mousemove', 移動);
}

/**

  • ノード内に条件を追加します。
  • @param {!options} options 設定のディクショナリ。
  • @param {!condition} 条件 表示する条件。
  • @param {!output} 出力 表示アキュムレータを出力します。
    */
    function display_condition(オプション, 条件, 出力) { threshold_format = d3.format('r');

if (condition.type === 'IS_MISSING') { display_node_text(オプション, , 出力); 戻る; }
${condition.attribute} is missing

if (condition.type === 'IS_TRUE') { display_node_text(オプション, , 出力); 戻る; }
${condition.attribute} is true

if (condition.type === 'NUMERICAL_IS_HIGHER_THAN') { format = d3.format('r'); display_node_text(オプション, ,出力); 戻る; }



${condition.attribute} >= ${threshold_format(condition.threshold)}


if (condition.type === 'CATEGORICAL_IS_IN') { display_node_text_with_tooltip(オプション, , , 出力); 戻る; }

${condition.attribute} in [...]
${condition.attribute} in [${condition.mask}]

if (condition.type === 'CATEGORICAL_SET_CONTAINS') { display_node_text_with_tooltip(オプション, , , 出力); 戻る; }

${condition.attribute} intersect [...]
${condition.attribute} intersect [${condition.mask}]

if (condition.type === 'NUMERICAL_SPARSE_OBLIQUE') { display_node_text_with_tooltip(オプション, , ,出力); 戻る; }

Sparse oblique split...
[${condition.attributes}]*[${condition.weights}]>=${ threshold_format(condition.threshold)}


display_node_text(
オプション, Non supported condition ${condition.type}, 出力);
}

/**

  • ノード内に値を追加します。

  • @param {!options} options 設定のディクショナリ。

  • @param {!value} value 表示する値。

  • @param {!output} 出力 表示アキュムレータを出力します。
    */
    function display_value(オプション, 値, 出力) { if (value.type === 'PROBABILITY') { const left_margin = 0; const right_margin = 50; const Lot_width = options.node_x_size - options.node_padding * 2 - left_margin - right_margin;




    let cusum = Array.from(d3.cumsum(value.distribution));
    cusum.unshift(0);
    const distribution_plot = Output.content.append('g').attr(
    'transform', translate(0,${output.vertical_offset + 0.5}));

    distribution_plot.selectAll('rect')
    .data(value.distribution)
    .join('rect')
    .attr('height', 10)
    .attr(
    'x',
    (d, i) =>
    (cusum[i] * プロット幅 + 左マージン + options.node_padding))
    .attr('幅', (d, i) => d * プロット幅)
    .style('fill', (d, i) => d3.schemeSet1[i]);

    const num_examples =
    Output.content.append('g')
    .attr('transform', translate(0,${output.vertical_offset}))
    .append('text')
    .attr('x', options.node_x_size - options.node_padding)
    .attr('alignment-baseline ', 'hanging')
    .attr('text-anchor', 'end')
    .text( (${value.num_examples}));

    const distribution_details = d3.create('ul');
    distribution_details.selectAll('li')
    .data(value.distribution)
    .join('li')
    .append('span')
    .text(
    (d, i) =>
    'クラス ' + i + ': ' + d3 .format('.3%')(value.distribution[i]));

    add_tooltip(オプション, distribution_plot, () => distribution_details.html());
    add_tooltip(options, num_examples, () => 'サンプルの数');

    出力.垂直オフセット += 10;
    戻る;
    }

if (value.type === 'REGRESSION') { display_node_text( options, 'value: ' + d3.format('r')(value.value) + + d3.format('.6')(value.num_examples ) + 出力); 戻る; }


(
)


display_node_text(オプション, Non supported value ${value.type}, 出力);
}

/**

  • ノード内に説明を追加します。
  • @param {!options} options 設定のディクショナリ。
  • @param {!explanation} description 表示する説明。
  • @param {!output} 出力 表示アキュムレータを出力します。
    */
    function display_explanation(options, description, Output) { // 説明の前のマージン。出力.垂直オフセット += 10;

display_node_text(
オプション, Non supported explanation ${explanation.type}, 出力);
}

/**

  • 木の端を描きます。
  • @param {!options} options 設定のディクショナリ。
  • @param {!graph} グラフを含むグラフ D3 検索ハンドル。
  • @param {!tree_struct} Tree_struct ツリーの構造 (ノードの配置、
  • data, etc.).
    

*/
function display_edges(options,graph,tree_struct) { // 親ノードと子ノードの間にベジェでエッジを描画します。functiondraw_single_edge(d) { return 'M' + (d.source.y + options.node_x_size) + ',' + d.source.x + ' C' + (d.source.y + options.node_x_size + options. edge_rounding) + ',' + d.source.x + ' ' + (d.target.y - options.edge_rounding) + ',' + d.target.x + ' ' + d.target.y + ',' + d.target.x; }






graph.append('g')
.attr('fill', 'none')
.attr('ストローク幅', 1.2)
.selectAll('path')
.data(tree_struct.links())
.join(' path')
.attr('d',draw_single_edge)
.attr(
'ストローク', d => (d.target === d.source.children[0]) ? '#0F0' : '#F00');
}

display_tree({“margin”: 10, “node_x_size”: 160, “node_y_size”: 28, “node_x_offset”: 180, “node_y_offset”: 33, “font_size”: 10, “edge_rounding”: 20, “node_padding”: 2 , “show_plot_bounding_box”: false, “labels”: “[“red”, “blue”, “green”]”}, {“condition”: {“type”: “NUMERICAL_IS_HIGHER_THAN”, “attribute”: “f1”, "しきい値": 1.5}、"子": [{"条件": {"タイプ": "CATEGORICAL_IS_IN"、"属性": "f2"、"マスク": ["猫"、"犬"]}、" Children": [{"value": {"type": "PROBABILITY", "distribution": [0.8, 0.1, 0.1], "num_examples": 10.0}}, {"value": {"type": "PROBABILITY" ”, “分布”: [0.1, 0.8, 0.1], “num_examples”: 20.0}}]}, {“値”: {“タイプ”: “確率”, “分布”: [0.1, 0.1, 0.8], “num_examples”: 30.0}}]}, “#tree_plot_34c8fb6cf7ca49eda845b971be7f0560”)

おすすめ

転載: blog.csdn.net/wjjc1017/article/details/135189646