PyTorch での解釈可能なニューラル ネットワーク モデルの実装

小さな手を動かして大金を稼いで、親指を立ててください!

代替

目的

深層学習システムの解釈可能性の欠如は、人間の信頼を築く上で大きな課題となっています。これらのモデルは複雑であるため、人間が意思決定の背後にある根本的な理由を理解することはほぼ不可能です。

深層学習システムの解釈可能性の欠如は、人間の信頼を妨げます。

この問題に対処するために、研究者たちは新しい解決策を積極的に研究しており、その結果、コンセプトベースのモデルなどの大きな革新が生まれています。これらのモデルは、モデルの透明性を高めるだけでなく、トレーニング プロセス中に人間が解釈可能な高レベルの概念 (「色」や「形状」など) を組み込むことで、システムの意思決定に対する新たな信頼感を促進します。したがって、これらのモデルは、学習した概念に基づいて予測に対するシンプルかつ直感的な説明を提供し、人間が意思決定の背後にある理由を検討できるようにします。それがすべてではありません!人間が学習した概念と対話することも可能になり、最終的な決定を制御できるようになります。

コンセプトベースのモデルを使用すると、人間がディープラーニング予測の背後にある推論を検証し、最終的な決定を制御できるようになります。

このブログ投稿[1]では、これらのテクニックを詳しく説明し、シンプルな PyTorch インターフェイスを使用して最先端の概念ベースのモデルを実装するツールを提供します。実践的な経験を通じて、これらの強力なモデルを活用して解釈可能性を高め、最終的に深層学習システムに対する人間の信頼を調整する方法を学びます。

概念的なボトルネックモデル

この入門では、概念的なボトルネック モデルについて詳しく説明します。2020 年機械学習国際会議で発表された論文で紹介されたこのモデルは、まず「色」や「形状」などの一連の概念を学習して予測し、次にこれらの概念を使用して下流の分類タスクを解決することを目的としています。

代替

このアプローチに従うことで、「入力オブジェクトは、{球形}で{赤}であるため、{リンゴ}である。」などの説明を提供する概念に予測を遡ることができます。

概念的ボトルネック モデルは、最初に「色」や「形状」などの一連の概念を学習し、次にこれらの概念を利用して下流の分類タスクを解決します。

達成

为了说明概念瓶颈模型,我们将重新审视著名的 XOR 问题,但有所不同。我们的输入将包含两个连续的特征。为了捕捉这些特征的本质,我们将使用概念编码器将它们映射为两个有意义的概念,表示为“A”和“B”。我们任务的目标是预测“A”和“B”的异或 (XOR)。通过这个例子,您将更好地理解概念瓶颈如何在实践中应用,并见证它们在解决具体问题方面的有效性。

我们可以从导入必要的库并加载这个简单的数据集开始:

import torch
import torch_explain as te
from torch_explain import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

x, c, y = datasets.xor(500)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

接下来,我们实例化一个概念编码器以将输入特征映射到概念空间,并实例化一个任务预测器以将概念映射到任务预测:

concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], 10),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(108),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(8, c.shape[1]),
    torch.nn.Sigmoid(),
)
task_predictor = torch.nn.Sequential(
    torch.nn.Linear(c.shape[1], 8),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(81),
)
model = torch.nn.Sequential(concept_encoder, task_predictor)

然后我们通过优化概念和任务的交叉熵损失来训练网络:

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form_c = torch.nn.BCELoss()
loss_form_y = torch.nn.BCEWithLogitsLoss()
model.train()
for epoch in range(2001):
    optimizer.zero_grad()

    # generate concept and task predictions
    c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_pred)

    # update loss
    concept_loss = loss_form_c(c_pred, c_train)
    task_loss = loss_form_y(y_pred, y_train)
    loss = concept_loss + 0.2*task_loss

    loss.backward()
    optimizer.step()

训练模型后,我们评估其在测试集上的性能:

c_pred = concept_encoder(x_test)
y_pred = task_predictor(c_pred)

concept_accuracy = accuracy_score(c_test, c_pred > 0.5)
task_accuracy = accuracy_score(y_test, y_pred > 0)

现在,在几个 epoch 之后,我们可以观察到概念和任务在测试集上的准确性都非常好(~98% 的准确性)!

由于这种架构,我们可以通过根据输入概念查看任务预测器的响应来为模型预测提供解释,如下所示:

c_different = torch.FloatTensor([01])
print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}")

c_equal = torch.FloatTensor([11])
print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}")

这会产生例如 f([0,1])=1 和 f([1,1])=0 ,如预期的那样。这使我们能够更多地了解模型的行为,并检查它对于任何相关概念集的行为是否符合预期,例如,对于互斥的输入概念 [0,1] 或 [1,0],它返回的预测y=1。

概念瓶颈模型通过将预测追溯到概念来提供直观的解释。

淹没在准确性与可解释性的权衡中

概念瓶颈模型的主要优势之一是它们能够通过揭示概念预测模式来为预测提供解释,从而使人们能够评估模型的推理是否符合他们的期望。

然而,标准概念瓶颈模型的主要问题是它们难以解决复杂问题!更一般地说,他们遇到了可解释人工智能中众所周知的一个众所周知的问题,称为准确性-可解释性权衡。实际上,我们希望模型不仅能实现高任务性能,还能提供高质量的解释。不幸的是,在许多情况下,当我们追求更高的准确性时,模型提供的解释往往会在质量和忠实度上下降,反之亦然。

在视觉上,这种权衡可以表示如下:

代替

可解释模型擅长提供高质量的解释,但难以解决具有挑战性的任务,而黑盒模型以提供脆弱和糟糕的解释为代价来实现高任务准确性。

このトレードオフを具体的な設定で説明するために、もう少し要求の厳しいベンチマークである「三角法」データセットに適用される概念的なボトルネック モデルを考えてみましょう。

x, c, y = datasets.trigonometry(500)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

このデータセットで同じネットワーク アーキテクチャをトレーニングした後、タスクの精度が大幅に低下し、約 80% に達するだけであることが観察されました。

概念的なボトルネック モデルは、タスクの精度と解釈の品質のバランスをとることができません。

ここで疑問が生じます。私たちは説明の正確さと質のどちらかを永遠に選択することを迫られるのでしょうか、それともより良いバランスをとる方法はあるのでしょうか?

参照

[1]

ソース:https://towardsdatascience.com/implement-interpretable-neural-models-in-pytorch-6a5932bdb078

この記事はmdniceマルチプラットフォームによって公開されています

おすすめ

転載: blog.csdn.net/swindler_ice/article/details/131317757