序文
ああ、私は以前にポットテンソルフローの元のコードを見ました。また、テキスト分類のためのグラフ畳み込みネットワークの元のコード解釈を思い出しました[tensorflow]
プロジェクトアドレス
https://github.com/iworldtong/text_gcn.pytorch
環境構成
テンソルフローバージョンの環境で、pytorch 1.7.1 + cu101が再インストールされました
コード分析
remove_words.py
元のコードですか
build_graph.py
これは元のコードでもあります。
好奇心が強い:scipyのcsr_matrix関数も使用されているのを見て、pytorchにも同様の行列演算がありますか?
train.py
いくつかの寸法といくつかの出力
adj (61603, 61603)
features (61603, 300)
y_train (61603, 20)
y_val (61603, 20)
y_test (61603, 20)
train_mask (61603,)
val_mask (61603,)
test_mask (61603,)
train_size 11314
test_size 7532
tm_train_mask torch.Size([61603, 20])
t_support[0] torch.Size([61603, 61603])
pre_sup torch.Size([61603, 200])
support0 torch.Size([61603, 61603])
out torch.Size([61603, 200])
logits * tm_train_mask[0] : tensor([ 0.0521, -0.0080, 0.2177, -0.1337, 0.1672, -0.0428, 0.0664, -0.1221,
0.0376, 0.0709, -0.3589, 0.2038, 0.0118, -0.1365, -0.2384, -0.1432,
0.0838, 0.1781, 0.2771, 0.1930], grad_fn=<SelectBackward>)
t_y_train[0]:tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0.], dtype=torch.float64)
torch.max(t_y_train, 1)[0]:tensor(8)
features = preprocess_features(features)
# return sparse_to_tuple(features)
return features.A
sparse_to_tuple関数を削除しました
def sparse_to_tuple(sparse_mx):
"""Convert sparse matrix to tuple representation."""
def to_tuple(mx):
if not sp.isspmatrix_coo(mx):
mx = mx.tocoo()
coords = np.vstack((mx.row, mx.col)).transpose()
values = mx.data
shape = mx.shape
return coords, values, shape
if isinstance(sparse_mx, list):
for i in range(len(sparse_mx)):
sparse_mx[i] = to_tuple(sparse_mx[i])
else:
sparse_mx = to_tuple(sparse_mx)
return sparse_mx
preprocess_adj(adj)
同じことが言えます
# return sparse_to_tuple(adj_normalized)
return adj_normalized.A
tm_train_mask = torch.transpose(torch.unsqueeze(t_train_mask、0)、1、0).repeat(1、y_train.shape [1])
(real_train_size + valid_size + vocab_size + test_size、)の元のベクトルを(real_train_size + valid_size + vocab_size + test_size、ラベルの数)テンソルに変換します
トレーニング
ああ、省略。
評価
from sklearn import metrics
print_log("Test Precision, Recall and F1-Score...")
print_log(metrics.classification_report(test_labels, test_pred, digits=4))
print_log("Macro average Test Precision, Recall and F1-Score...")
print_log(metrics.precision_recall_fscore_support(test_labels, test_pred, average='macro'))
print_log("Micro average Test Precision, Recall and F1-Score...")
print_log(metrics.precision_recall_fscore_support(test_labels, test_pred, average='micro'))
独自のデータセットを構築する-wiki80
再生するには80/10 / 10wiki_727Kを選択します。
adj (7002, 7002)
features (7002, 300)
y_train (7002, 2)
y_val (7002, 2)
y_test (7002, 2)
train_mask (7002,)
val_mask (7002,)
test_mask (7002,)
train_size 3927
test_size 352
独自のデータセットを構築する-wiki800
adj (109772, 109772)
features (109772, 300)
y_train (109772, 2)
y_val (109772, 2)
y_test (109772, 2)
train_mask (109772,)
val_mask (109772,)
test_mask (109772,)
train_size 83497
test_size 4564