ASGCN のグラフ畳み込みネットワーク (GCN)



序文

この記事では、グラフ畳み込みネットワークを使用する目的から始めて、まずグラフ畳み込みネットワークのソースと式を簡単に紹介し、次に例を使用してグラフ畳み込みネットワークをコーディングします。


1.理論部分

1.1 なぜグラフ畳み込みネットワークがあるのですか?

CNNにせよRNNにせよ、通常のデータに直面する.グラフのような不規則なデータに直面すると、元のネットワークを整列させて特徴抽出を行うことができない.しかし、グラフのようなデータは社会に広く存在し、その方法が必要である.グラフ データを抽出するために、Graph Convolutional Networks の出現により、この問題が解決されました。
ここに画像の説明を挿入

1.2 グラフ畳み込みネットワークの導出過程

導出部分には通信関連の知識が含まれます. 主なコアは、時間領域の畳み込みが周波数領域の乗算に相当することです. 時間領域の畳み込み演算は乗算の周波数領域に相当し、乗算結果は次のように変換されます.タイムドメイン。GCN の威力は、トレーニングを受けていなくても、パラメータ W が完全にランダムに初期化されていても、GCN によって抽出された特徴が非常に優れていることです。

1.3 グラフ畳み込みネットワークの公式

式の由来については、ドキュメント. ネットワークの簡単な構造を下の図に示します.
ここに画像の説明を挿入
グラフ畳み込みの層間の計算式は次のとおりです。
H ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 H ( l ) W ( l ) ) \pmb{H^{ (l+1 )}=\sigma ( \tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H ^{(l )}W^{(l)} )}H( l + 1 )=s (D21D21H( l ) W( l ) )
ここで:

A ~ \tilde{A}~ :A ~ = A + I \tilde{A}=A+I=+I、A はグラフの隣接行列、I は単位行列です;
D ~ \tilde{D}D~ :D ~ \tilde{D}D~A ~ \tilde{A}~の次数行列 (次数行列)は、各ノードの次数を表し、D ii = ∑ j = 1 i A ij D_{ii}=\sum_{j=1}^iA_{ij}Dii=j = 1ij;
H: 各層の特徴、入力層の場合は X;
σ \sigmaσ : 非線形活性化関数;
W: 接続層の重みパラメータ;

2. コードの実装

ASGCN における畳み込み層の計算式は次のとおりです。
hil = R el U ( ∑ j = 1 n A ij W lgjl ) di + 1 + bl ) \pmb{h_i^{l}=RelU(\frac{\sum_{ j=1 }^{n} A_{ij} W^lg_{j}^{l})}{d_i+1}+b^l)}時間l=R e l U (d+1j = 1nijWラグ_jl)+bl )
次の計算式に従ってコードを作成します。

import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.FloatTensor(out_features))

    def forward(self, text, adj):
        hidden = torch.matmul(text, self.weight) # 权重self.weight随机产生
        denom = torch.sum(adj, dim=1, keepdim=True) + 1  # 加一保证做除法时分母不为零
        output = torch.matmul(adj, hidden) / denom
        output = F.relu(output + self.bias)
        print(output)
        return output
def main():
    # 假设该句子经过构建依赖树后的邻接矩阵为adj
    adj =torch.tensor([
        [1., 1., 0., 0., 0., 0., 0., 1., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 1.],
    ])
    # 假设一个句子中有10个单词,从前向后单词对应的索引为[0, 1, 2, 3, 3, 4, 6,0, 1, 2]
    input = torch.tensor([0, 1, 2, 3, 3, 4, 6,0, 1, 2], dtype=torch.long)
    embedding = torch.nn.Embedding(10, 50)
    x = embedding(input)  # 生成每个单词对应的词嵌入,维度为50
    gc1 = GraphConvolution(50, 10)
    gc1(x, adj)
if __name__ == '__main__':
    main()

输出:
tensor([[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02 、4.3066e+21]、
[1.1561e+19、6.8794e+11、2.7253e+20、3.0866e+29、1.1547e+19、4.1988e+07、3.0357e+32、1.1547e+19、6.4069e +02、4.3066e+21]、
[1.1561e+19、6.8794e+11、2.7253e+20、3.0866e+29、1.1547e+19、4.1988e+07、3.0357e+32、1.1547e+19、 6.4069e+02、4.3066e+21]、
[1.1561e+19、6.8794e+11、2.7253e+20、3.0866e+29、1.1547e+19、4.1988e+07、3.0357e+32、1.1547e+ 19、6.4069e+02、4.3066e+21]、
[1.1561e+19、6.8794e+11、2.7253e+20、3.0866e+29、1.1547e+19、4.1988e+07、3.0357e+32、1.1547 e+19、6.4069e+02、4.3066e+21]、
[1.1561e+19、6.8794e+11、2.7253e+20、3.0866e+29、1.1547e+19、4.1988e+07、3.0357e+32 、1.1547e+19、6.4069e+02、4.3066e+21]、
[1.1561e+19、6.8794e+11、2.7253e+20、3.0866e+29、1.1547e+19、4.1988e+07、3.0357e+32、1.1547e+19、6.4069e+02、4.3066e+21 ]、
[1.1561e+19、6.8794e+11、2.7253e+20、3.0866e+29、1.1547e+19、4.1988e+07、3.0357e+32、1.1547e+19、6.4069e+02、4.3066e +21]、
[1.1561e+19、6.8794e+11、2.7253e+20、3.0866e+29、1.1547e+19、4.1988e+07、3.0357e+32、1.1547e+19、6.4069e+02、 4.3066e+21]、
[1.1561e+19、6.8794e+11、2.7253e+20、3.0866e+29、1.1547e+19、4.1988e+07、3.0357e+32、1.1547e+19、6.4069e+ 02、4.3066e+21]]、
grad_fn =)


参考文献

  1. Graph Convolutional Network GCN の理解と詳細な導出 Graph Convolutional Network (スペクトル ドメイン GCN)
  2. Graph Convolutional Networks (GCN) の詳細な紹介

おすすめ

転載: blog.csdn.net/qq_40940944/article/details/128709733