GNNExplainer コード解釈とその PyG 実装
前回のグラフ ニューラル ネットワークの解釈方法と GNNexplainer コード例に関するブログ投稿に引き続き、ここでは GNNExplainer のソース コードを単純に分析し、PyTorch Geometric を使用して手動で実装します。
GNNExplainer のソースアドレス: https://github.com/RexYing/gnn-model-explainer
GNN エクスプローラーを使用する
(1) インストール:
git clone https://github.com/RexYing/gnn-model-explainer
Python3.7 を使用して仮想環境を作成することをお勧めします。
virtualenv venv -p /usr/local/bin/python3
source venv/bin/activate
(2) GCN モデルをトレーニングする
python train.py --dataset=EXPERIMENT_NAME
ここで、EXPERIMENT_NAME は、再現する実験の名前を表します。
GCN モデルをトレーニングするためのオプションの完全なリスト:
python train.py --help
(3) GCN モデルの解釈
インタープリタを実行するには、次のコマンドを実行します。
python explainer_main.py --dataset=EXPERIMENT_NAME
(4) Tensorboard を使用した視覚的解釈: Tensorboard を通じて最適化された結果を視覚化できます。
tensorboard --logdir log
GNNExplainer ソースコードの高速読み取り
GNNExplainer では、次の 2 つの角度からグラフを説明します。
- エッジ (エッジ): グラフに表示される各エッジの確率を示すエッジ マスクが生成されます。値は 0 ~ 1 の浮動小数点数です。エッジ マスクは重みとして使用することもできます。これは、topk のエッジによって接続された部分グラフを取ることで説明できます。
- ノード特徴: ノード特徴 (NF) はノード ベクトルです。たとえば、128 次元のノードは 128 の特徴を表し、同時に各特徴の重みを表す NF マスクを生成します。これはオプションです。
-
Explainer ディレクトリ内のクラスは
ExplainModel
、torch.nn.Module を継承して GNNExplainer ネットワークのモジュール構造を定義します。init
初期化時には、construct_edge_mask
sumconstruct_feat_mask
関数を使用して学習対象の2つ(それぞれn×nn×nの2種類の変数mask
に対応)を初期化しますnn.Parameter
n×n次元mask
、d
すべて 0feat_mask
)、diag_mask
つまり、主対角が 0、残りの要素が 1 の行列で、_masked_adj
関数に使用されます。_masked_adj
この関数はmask
sigmod または ReLU でアクティブ化され、独自の転置を追加して 2 で割って対称行列に変換し、それを乗算して、diag_mask
最後に元の隣接行列adj
を に変換しますmasked_adj
。
-
Explainer
このクラスは説明ロジックを実装しており、main 関数はその 1 つでありexplain
、単一ノードでの元のモデルの予測結果を説明するために使用されます。主な手順は次のとおりです。- 部分グラフの
adj
,x
,を取り出しますlabel
。グラフの説明:graph_idx
対応する計算グラフ全体を取得します;ノードの説明:extract_neighborhood
関数を呼び出してnum_gc_layers
ノード次数の近傍を取得します。 - 受信モデル予測出力を に
pred
変換しますpred_label
。 - build
ExplainModule
、num_epochs
ラウンド トレーニングを実行します (フォワード + バックプロップ)
- 部分グラフの
adj = torch.tensor(sub_adj, dtype=torch.float)
x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
label = torch.tensor(sub_label, dtype=torch.long)
if self.graph_mode:
pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
print("Graph predicted label: ", pred_label)
else:
pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
print("Node predicted label: ", pred_label[node_idx_new])
explainer = ExplainModule(
adj=adj,
x=x,
model=self.model,
label=label,
args=self.args,
writer=self.writer,
graph_idx=self.graph_idx,
graph_mode=self.graph_mode,
)
if self.args.gpu:
explainer = explainer.cuda()
...
# NODE EXPLAINER
def explain_nodes(self, node_indices, args, graph_idx=0):
...
def explain_nodes_gnn_stats(self, node_indices, args, graph_idx=0, model="exp"):
...
# GRAPH EXPLAINER
def explain_graphs(self, graph_indices):
...
explain_nodes
、explain_nodes_gnn_stats
、explain_graphs
これら3つの機能をベースに実現されています。
forward
sum関数は以下で分析されますloss
。
順伝播
まず、学習対象のパラメータマスクとfeat_maskに、元の隣接行列と特徴ベクトルを乗算して、変換後の合計を求めmasked_adj
ますx
。前者は_masked_adj
関数を呼び出すことによって実行され、後者は次のように実装されます。
feat_mask = (
torch.sigmoid(self.feat_mask)
if self.use_sigmoid
else self.feat_mask
)
if marginalize:
std_tensor = torch.ones_like(x, dtype=torch.float) / 2
mean_tensor = torch.zeros_like(x, dtype=torch.float) - x
z = torch.normal(mean=mean_tensor, std=std_tensor)
x = x + z * (1 - feat_mask)
else:
x = x * feat_mask
完全なコードは次のとおりです。
ここで説明する必要があるのは True の場合です。marginalize
論文のバイナリ特徴セレクター F の学習を参照してください。
mask
同じことを学習するとfeature_mask
、場合によっては重要な特徴が無視されることがあります(学習された特徴マスクも 0 に近い値です)。そのため、XS X_Sによれば、バツSモンテカルロ法の経験的周辺分布を使用して、X = XSFX=X_S^Fをサンプリングします。バツ=バツSふ。- 確率変数XXを解くにはXのバックプロパゲーションの問題には、「再パラメータ化」の手法が導入されています。つまり、パラメータのない確率変数ZZZの確定変換: X = Z + ( XS − Z ) ⊙ FX=Z+(X_S-Z)\odot Fバツ=Z+( XS−Z )⊙F s . と。∑ j F j ≤ KF st \sum_{j}F_j\le K_Fs 。と。j∑Fj≤Kふ
其中, Z Z Zは経験的分布に従ってサンプリングして得られたddd次元の確率変数、KF K_FKふは、保持する特徴の最大数を表すパラメータ (utils/io_utils.py
の関数denoise_graph
)です。
次に、masked_adj
その合計をx
元のモデルに入力してExplainModule
結果を取得しますpred
。
損失関数
loss = pred_loss + size_loss + lap_loss + mask_ent_loss + feat_size_loss
合計損失には 5 つの項目が含まれていることがわかります。論文に対応する損失関数の式を除きpred_loss
、他の損失の関数は論文「追加の制約を説明に統合する」を参照してください。それらの重みは coeff で定義されています。
self.coeffs = {
"size": 0.005,
"feat_size": 1.0,
"ent": 1.0,
"feat_ent": 0.1,
"grad": 0,
"lap": 1.0,
}
pred_loss
mi_obj = False
if mi_obj:
pred_loss = -torch.sum(pred * torch.log(pred))
else:
pred_label_node = pred_label if self.graph_mode else pred_label[node_idx]
gt_label_node = self.label if self.graph_mode else self.label[0][node_idx]
logit = pred[gt_label_node]
pred_loss = -torch.log(logit)
ここでpred
、 は現在の予測結果、pred_label
は元のフィーチャの予測結果です。
mask_ent_loss
# entropy
mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
mask_ent_loss = self.coeffs["ent"] * torch.mean(mask_ent)
size_loss
# size
mask = self.mask
if self.mask_act == "sigmoid":
mask = torch.sigmoid(self.mask)
elif self.mask_act == "ReLU":
mask = nn.ReLU()(self.mask)
size_loss = self.coeffs["size"] * torch.sum(mask)
feat_size_loss
# pre_mask_sum = torch.sum(self.feat_mask)
feat_mask = (
torch.sigmoid(self.feat_mask) if self.use_sigmoid else self.feat_mask
)
feat_size_loss = self.coeffs["feat_size"] * torch.mean(feat_mask)
lap_loss
# laplacian
D = torch.diag(torch.sum(self.masked_adj[0], 0))
m_adj = self.masked_adj if self.graph_mode else self.masked_adj[self.graph_idx]
L = D - m_adj
pred_label_t = torch.tensor(pred_label, dtype=torch.float)
if self.args.gpu:
pred_label_t = pred_label_t.cuda()
L = L.cuda()
if self.graph_mode:
lap_loss = 0
else:
lap_loss = (self.coeffs["lap"] * (pred_label_t @ L @ pred_label_t) / self.adj.numel())
GNNExplainer グラフ分類の説明に基づく PyG コード例
グラフ分類問題を説明するには、次の 2 つの重要なポイントがあります。
- 学習するマスクは部分グラフを取得せずにグラフ全体に作用します
- ラベル予測と損失関数の対象は単一のグラフです
実装コードは次のとおりです。
#!/usr/bin/env python
# encoding: utf-8
# Created by BIT09 at 2023/4/28
import torch
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from math import sqrt
from tqdm import tqdm
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph, to_networkx
EPS = 1e-15
class GNNExplainer(torch.nn.Module):
r"""
Args:
model (torch.nn.Module): The GNN module to explain.
epochs (int, optional): The number of epochs to train.
(default: :obj:`100`)
lr (float, optional): The learning rate to apply.
(default: :obj:`0.01`)
log (bool, optional): If set to :obj:`False`, will not log any learning
progress. (default: :obj:`True`)
"""
coeffs = {
'edge_size': 0.001,
'node_feat_size': 1.0,
'edge_ent': 1.0,
'node_feat_ent': 0.1,
}
def __init__(self, model, epochs=100, lr=0.01, log=True, node=False): # disable node_feat_mask by default
super(GNNExplainer, self).__init__()
self.model = model
self.epochs = epochs
self.lr = lr
self.log = log
self.node = node
def __set_masks__(self, x, edge_index, init="normal"):
(N, F), E = x.size(), edge_index.size(1)
std = 0.1
if self.node:
self.node_feat_mask = torch.nn.Parameter(torch.randn(F) * 0.1)
std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
self.edge_mask = torch.nn.Parameter(torch.randn(E) * std)
self.edge_mask = torch.nn.Parameter(torch.zeros(E) * 50)
for module in self.model.modules():
if isinstance(module, MessagePassing):
module.__explain__ = True
module.__edge_mask__ = self.edge_mask
def __clear_masks__(self):
for module in self.model.modules():
if isinstance(module, MessagePassing):
module.__explain__ = False
module.__edge_mask__ = None
if self.node:
self.node_feat_masks = None
self.edge_mask = None
def __num_hops__(self):
num_hops = 0
for module in self.model.modules():
if isinstance(module, MessagePassing):
num_hops += 1
return num_hops
def __flow__(self):
for module in self.model.modules():
if isinstance(module, MessagePassing):
return module.flow
return 'source_to_target'
def __subgraph__(self, node_idx, x, edge_index, **kwargs):
num_nodes, num_edges = x.size(0), edge_index.size(1)
if node_idx is not None:
subset, edge_index, mapping, edge_mask = k_hop_subgraph(
node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,
num_nodes=num_nodes, flow=self.__flow__())
x = x[subset]
else:
x = x
edge_index = edge_index
row, col = edge_index
edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
edge_mask[:] = True
mapping = None
for key, item in kwargs:
if torch.is_tensor(item) and item.size(0) == num_nodes:
item = item[subset]
elif torch.is_tensor(item) and item.size(0) == num_edges:
item = item[edge_mask]
kwargs[key] = item
return x, edge_index, mapping, edge_mask, kwargs
def __graph_loss__(self, log_logits, pred_label):
loss = -torch.log(log_logits[0, pred_label])
m = self.edge_mask.sigmoid()
loss = loss + self.coeffs['edge_size'] * m.sum()
ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
loss = loss + self.coeffs['edge_ent'] * ent.mean()
return loss
def visualize_subgraph(self, node_idx, edge_index, edge_mask, y=None,
threshold=None, **kwargs):
r"""Visualizes the subgraph around :attr:`node_idx` given an edge mask
:attr:`edge_mask`.
Args:
node_idx (int): The node id to explain.
edge_index (LongTensor): The edge indices.
edge_mask (Tensor): The edge mask.
y (Tensor, optional): The ground-truth node-prediction labels used
as node colorings. (default: :obj:`None`)
threshold (float, optional): Sets a threshold for visualizing
important edges. If set to :obj:`None`, will visualize all
edges with transparancy indicating the importance of edges.
(default: :obj:`None`)
**kwargs (optional): Additional arguments passed to
:func:`nx.draw`.
:rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
"""
assert edge_mask.size(0) == edge_index.size(1)
if node_idx is not None:
# Only operate on a k-hop subgraph around `node_idx`.
subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,
num_nodes=None, flow=self.__flow__())
edge_mask = edge_mask[hard_edge_mask]
subset = subset.tolist()
if y is None:
y = torch.zeros(edge_index.max().item() + 1,
device=edge_index.device)
else:
y = y[subset].to(torch.float) / y.max().item()
y = y.tolist()
else:
subset = []
for index, mask in enumerate(edge_mask):
node_a = edge_index[0, index]
node_b = edge_index[1, index]
if node_a not in subset:
subset.append(node_a.item())
if node_b not in subset:
subset.append(node_b.item())
y = [y for i in range(len(subset))]
if threshold is not None:
edge_mask = (edge_mask >= threshold).to(torch.float)
data = Data(edge_index=edge_index, att=edge_mask, y=y,
num_nodes=len(y)).to('cpu')
G = to_networkx(data, edge_attrs=['att']) # , node_attrs=['y']
mapping = {
k: i for k, i in enumerate(subset)}
G = nx.relabel_nodes(G, mapping)
kwargs['with_labels'] = kwargs.get('with_labels') or True
kwargs['font_size'] = kwargs.get('font_size') or 10
kwargs['node_size'] = kwargs.get('node_size') or 800
kwargs['cmap'] = kwargs.get('cmap') or 'cool'
pos = nx.spring_layout(G)
ax = plt.gca()
for source, target, data in G.edges(data=True):
ax.annotate(
'', xy=pos[target], xycoords='data', xytext=pos[source],
textcoords='data', arrowprops=dict(
arrowstyle="->",
alpha=max(data['att'], 0.1),
shrinkA=sqrt(kwargs['node_size']) / 2.0,
shrinkB=sqrt(kwargs['node_size']) / 2.0,
connectionstyle="arc3,rad=0.1",
))
nx.draw_networkx_nodes(G, pos, node_color=y, **kwargs)
nx.draw_networkx_labels(G, pos, **kwargs)
return ax, G
def explain_graph(self, data, **kwargs):
self.model.eval()
self.__clear_masks__()
x, edge_index, batch = data.x, data.edge_index, data.batch
num_edges = edge_index.size(1)
# Only operate on a k-hop subgraph around `node_idx`.
x, edge_index, _, hard_edge_mask, kwargs = self.__subgraph__(node_idx=None, x=x, edge_index=edge_index,
**kwargs)
# Get the initial prediction.
with torch.no_grad():
log_logits = self.model(data, **kwargs)
probs_Y = torch.softmax(log_logits, 1)
pred_label = probs_Y.argmax(dim=-1)
self.__set_masks__(x, edge_index)
self.to(x.device)
if self.node:
optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
lr=self.lr)
else:
optimizer = torch.optim.Adam([self.edge_mask], lr=self.lr)
epoch_losses = []
for epoch in range(1, self.epochs + 1):
epoch_loss = 0
optimizer.zero_grad()
if self.node:
h = x * self.node_feat_mask.view(1, -1).sigmoid()
log_logits = self.model(data, **kwargs)
pred = torch.softmax(log_logits, 1)
loss = self.__graph_loss__(pred, pred_label)
loss.backward()
optimizer.step()
epoch_loss += loss.detach().item()
epoch_losses.append(epoch_loss)
edge_mask = self.edge_mask.detach().sigmoid()
print(edge_mask)
self.__clear_masks__()
return edge_mask, epoch_losses
def __repr__(self):
return f'{
self.__class__.__name__}()'
参考文献
- gnn 説明者
- グラフ ニューラル ネットワークの解釈方法と GNNexplainer コード例
- Pytorch は GNNExplainer を実装します
- グラフ ニューラル ネットワークを説明する方法 — GNNExplainer
- https://gist.github.com/hongxuenong/9f7d4ce96352d4313358bc8368801707