決定木の枝刈り: モデルの過学習を解決する [決定木、機械学習]

枝刈りを通じてデシジョン ツリーの過学習問題を解決する方法

分类デシジョン ツリーは、問題を解決するために使用される強力な機械学習アルゴリズムです回归デシジョン ツリー モデルは、ツリー構造のデシジョン ルールを通じて予測を行いますが、デシジョン ツリーを構築する際に、モデルがトレーニング データに対しては良好なパフォーマンスを発揮しますが、目に見えないデータに対してはパフォーマンスが低下するという、過学習の問題が頻繁に発生します。

過剰適合の脅威

機械学習では、过拟合トレーニング データでは良好なパフォーマンスを発揮するモデルを指す一般的な問題ですが、目に見えないデータに一般化するとパフォーマンスが低下します。これは、デシジョン ツリーが各トレーニング サンプルをできるだけ正確に適合させようとする傾向があり、その結果、ツリーが複雑すぎて、真のデータ パターンだけでなく、トレーニング セット内のノイズやランダムな変動を捕捉することになるためです。

デシジョン ツリーの枝刈り: モデルの過学習を解決する

デシジョン ツリー プルーニングは、デシジョン ツリーの複雑さを軽減し、トレーニング データの過剰適合を防ぐ手法です。枝刈りの目的は、デシジョン ツリー (またはデシジョン ルール) の一部の分岐を削除してツリーの深さと複雑さを軽減し、それによってモデルの汎化能力を向上させることです。つまり、枝刈りは、トレーニング データ内の特定の状況への過剰適合を減らすことで、モデルのより幅広い適用性を実現します。

1.前庭剪定

事前枝刈りでは、ツリーが複雑になりすぎないように、デシジョン ツリー構築プロセス中にノードを分割する前に手順を実行します。事前枝刈りの方法には、ノードの分割に必要な最大深さ、リーフ ノードの最小数、またはサンプルの最小数の設定が含まれます。これらの条件付き制限により、ツリーの成長中に不要な分岐を回避できるため、オーバーフィットのリスクが軽減されます。

例:出会い系 Web サイトのデータセットでは、デシジョン ツリーを使用して、ユーザーが 2 回目のデートを開始するかどうかを予測します。順方向枝刈りにより、デシジョン ツリーの深さを制限し、小さすぎるデータ サブセットに対して生成される分岐が多すぎないようにすることができるため、モデルの汎化能力が向上します。

from sklearn.tree import DecisionTreeClassifier

# 创建一个决策树分类器,并设置最大深度为5
tree_classifier = DecisionTreeClassifier(max_depth=5)

# 训练模型
tree_classifier.fit(X_train, y_train)

# 在测试集上进行预测
y_pred = tree_classifier.predict(X_test)

2. 剪定後

ポスト枝刈りは、完全なデシジョン ツリーを構築した後に不要な枝を削除してツリーの複雑さを軽減することです。ポスト枝刈り手法では、最初に完全に成長した決定木を構築し、次に枝の不純物 (ジニ不純物やエントロピーなど) を計算し、さまざまな枝刈りスキームのパフォーマンスを比較することによって、枝刈りに適切な枝を選択します。この方法は計算量が多くなりますが、多くの場合、より正確な枝刈り結果が得られます。

例:医療診断では、決定木を使用して、患者が特定の病気に罹患しているかどうかを予測します。ポストプルーニングは、最終的な診断にあまり寄与しないブランチを削除するのに役立ち、モデルの理解と解釈が容易になります。

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import _tree

def prune_index(inner_tree, index, threshold):
    if inner_tree.value[index].min() < threshold:
        # 将子树叶子节点设置为空
        inner_tree.children_left[index] = _tree.TREE_LEAF
        inner_tree.children_right[index] = _tree.TREE_LEAF

# 创建一个决策树分类器,并训练完整树
tree_classifier = DecisionTreeClassifier()
tree_classifier.fit(X_train, y_train)

# 设置剪枝的阈值
prune_threshold = 0.01

# 后剪枝
prune_index(tree_classifier.tree_, 0, prune_threshold)

# 在测试集上进行预测
y_pred = tree_classifier.predict(X_test)

違いと概要

事前枝刈りと事後枝刈りはどちらもデシジョン ツリーの過剰適合問題を解決するために使用できますが、実装にはいくつかの違いがあります。

  • 事前枝刈りは、デシジョン ツリーの構築中に実行される手段であり、ツリーの成長中に不要な分岐を回避し、複雑さを制限できます。

  • ポスト枝刈りは、完全な決定木が構築された後に実行され、不必要な枝を削除してツリーの複雑さを軽減します。通常、不純物を計算し、さまざまな枝刈りスキームのパフォーマンスを比較する必要があります。

おすすめ

転載: blog.csdn.net/qq_22841387/article/details/133431866