ASGCN之图卷积网络(GCN)



前言

本文从使用图卷积网络的目的出发,先对图卷积网络的来源与公式做简要介绍,之后通过一个例子来代码实现图卷积网络。


1. 理论部分

1.1 为什么会出现图卷积网络?

无论是CNN还是RNN,面对的都是规则的数据,面对图这种不规则的数据,原有网络无法对齐进行特征提取,而图这种数据在社会中广泛存在,需要设计一种方法对图数据进行提取,图卷积网络(Graph Convolutional Networks)的出现刚好解决了这一问题。
在这里插入图片描述

1.2 图卷积网络的推导过程

推导部分涉及通信相关知识,其主要核心是时域卷积等价于频域相乘,将时域卷积运算等价到频域进行相乘运算,再将相乘结果转化到时域。GCN的强悍之处在于,即使不训练,完全使用随机初始化的参数W,GCN提取出来的特征就以及十分优秀了。

1.3 图卷积网络的公式

公式由来请参考文献 图卷积网络(Graph Convolutional Networks, GCN)详细介绍,其网络的简易结构如下图所示。
在这里插入图片描述
图卷积的层与层之间的计算公式为:
H ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 H ( l ) W ( l ) ) \pmb{H^{(l+1)}=\sigma ( \tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)} )} H(l+1)=σ(D~21A~D~21H(l)W(l))
式中:

A ~ \tilde{A} A~: A ~ = A + I \tilde{A}=A+I A~=A+I,A为图的邻接矩阵,I为单位矩阵;
D ~ \tilde{D} D~: D ~ \tilde{D} D~ A ~ \tilde{A} A~的度矩阵(degree matrix),表示每个结点度的数量, D i i = ∑ j = 1 i A i j D_{ii}=\sum_{j=1}^iA_{ij} Dii=j=1iAij;
H:每一层的特征,对于输入层,其是X;
σ \sigma σ:非线性激活函数;
W:连接层的权重参数;

2. 代码实现

在ASGCN中卷积层的计算公式为:
h i l = R e l U ( ∑ j = 1 n A i j W l g j l ) d i + 1 + b l ) \pmb{h_i^{l}=RelU(\frac{\sum_{j=1 }^{n} A_{ij} W^lg_{j}^{l})}{d_i+1}+b^l)} hil=RelU(di+1j=1nAijWlgjl)+bl)
依据计算公式构建代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.FloatTensor(out_features))

    def forward(self, text, adj):
        hidden = torch.matmul(text, self.weight) # 权重self.weight随机产生
        denom = torch.sum(adj, dim=1, keepdim=True) + 1  # 加一保证做除法时分母不为零
        output = torch.matmul(adj, hidden) / denom
        output = F.relu(output + self.bias)
        print(output)
        return output
def main():
    # 假设该句子经过构建依赖树后的邻接矩阵为adj
    adj =torch.tensor([
        [1., 1., 0., 0., 0., 0., 0., 1., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 1.],
    ])
    # 假设一个句子中有10个单词,从前向后单词对应的索引为[0, 1, 2, 3, 3, 4, 6,0, 1, 2]
    input = torch.tensor([0, 1, 2, 3, 3, 4, 6,0, 1, 2], dtype=torch.long)
    embedding = torch.nn.Embedding(10, 50)
    x = embedding(input)  # 生成每个单词对应的词嵌入,维度为50
    gc1 = GraphConvolution(50, 10)
    gc1(x, adj)
if __name__ == '__main__':
    main()

输出:
tensor([[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07, 3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07, 3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21]],
grad_fn=)


参考资料

  1. 图卷积网络 GCN Graph Convolutional Network(谱域GCN)的理解和详细推导
  2. 图卷积网络(Graph Convolutional Networks, GCN)详细介绍

猜你喜欢

转载自blog.csdn.net/qq_40940944/article/details/128709733
今日推荐