SAGPool - Self-Attention Graph Pooling ICML 2019

论文:Self-Attention Graph Pooling

作者:Junhyun Lee, Inyeop Lee, Jaewoo Kang
韩国首尔高丽大学计算机科学与工程系

来源:ICML 2019

论文链接:
Arxiv: https://arxiv.org/abs/1904.08082

代码地址:https://github.com/inyeoplee77/SAGPool

本文作者提出一种新的基于self-attention机制的图池化方法SAGPool,方法充分考虑了节点的特征和图的拓扑结构。在图分类评测任务上SAGPool效果拔群。

SAGPool具有前几种方法的优点:分层池化,同时考虑节点特征和图的拓扑结构(因为利用图卷积得到self-attention分数),合理的复杂度,以及端到端表示学习。SAGPool是第一个使用self-attention进行图池化处理并实现高性能的方法。SAGPool参数量一致,不用考虑输入图的大小。

1 相关介绍

Motivation

目前,图池化的方法比图卷积的方法要少,而现存的基于池化的方法存在一些问题:

  • 以往的基于拓扑的池化的研究都只考虑了图的拓扑结构。
  • 而全局池化方法只考虑了图的特征。
  • 分层池化可以学到图的层次表示(DIFFPOOL),允许图神经网络(GNNs)以端到端方式汇聚后获得按比例缩小的图,但是具有O(N^2)的空间复杂度,其参数数量取决于节点数量。分层池化gPool解决了复杂度问题,但没有考虑图的拓扑结构

创新性

文中提出了SAGPool,这是一种基于层次图池化的Self-Attention Graph方法。

  • SAGPool方法可以使用相对较少的参数以端到端方式学习分层表示
  • 利用self-attention机制来区分应该删除的节点和应该保留的节点
  • 基于图卷积计算注意力分数的self-attention机制,考虑了节点特征和图的拓扑结构

背景知识

self-attention 和 masked attention的区别

为什么说文中的attention机制是一种self-attention呢?和GAT中的marsked attention有什么区别呢?

self-attention是一种Global graph attention,会将注意力分配到图中所有的节点上,直接计算图结构中任意两个节点之间的关系,一步到位地获取图结构的全局几何特征。

self−attention利用了attention机制,分三个阶段进行计算:

  • (1) 引入不同的函数和计算机制,根据Query和某个Key,计算两者的相似性或者相关性,最常见的方法包括:求两者的向量点积、求两者的向量Cosine相似性或者通过再引入额外的神经网络来求值;
  • (2) 引入类似softmax的计算方式对第一阶段的得分进行数值转换,一方面可以进行归一化,将原始计算分值整理成所有元素权重之和为11的概率分布;另一方面也可以通过softmaxs的内在机制更加突出重要元素的权重;
  • (3)第二阶段的计算结果ai即为valuei对应的权重系数,然后进行加权求和即可得到attention数值。

通过self-attention注意力机制可以计算任意两个样本的关系,一个样本可以用其他所有样本来表示,但是存在一些问题:
(1)基于空间相似假设,一个样本应与一定范围内的样本关系较密切
(2)样本较多的时候,计算量非常大。

为了解决这上述问题,GAT中使用了一种 masked attention 的方法:对于一个样本来说只利用邻域内的样本计算注意力系数和新的表示,即仅将注意力分配到节点的一阶邻居节点集上

2 相关工作

图数据池化方法可以分为以下三类:基于拓扑的池化、全局池化和分层池化。

基于拓扑的池化

基于拓扑的池化主要考虑了图的结构特征。早期的工作使用的是图的粗化算法,而不是使用神经网络。谱聚类算法利用特征分解得到粗化图。然而,由于特征分解的时间复杂度过大问题,需要一些替代方法:
(1) Weighted graph cuts without eigenvectors a multilevel approach, 2007
(2)在最近的GNN模型中Graclus被用作池化模块:

  • Convolutional neural networks on graphs with fast localized spectral filtering,NIPS 2016
  • Hybrid approach of relation network and localized graph convolutional filtering for breast cancer subtype classification,IJCAI 2018

全局池化

与基于拓扑的池化方法不同,全局池化方法考虑了图的属性特征。全局池化方法使用求和或神经网络对每个层中所有节点的表示进行一次性pool:
(1)Neural message passing for quantum chemistry,2017)
将GNNs视为消息传递方案,提出了一种图分类的通用框架,利用Set2Set方法可以获得整个图的表示。
(2)An end-to-end deep learning architecture for graph classification,AAAI 2018
也叫SortPool,它根据图的结构对节点的embeddings进行排序,并将排序后的embeddings传递给下一层。

分层池化

全局池化方法没有学习对捕获图结构信息至关重要的层次表示。分层池化方法的主要动机是建立一个能够学习每一层中基于特征或拓扑的节点分配模型:
(1)[DIFFPOOL] Hierarchical Graph Representation Learning with Differentiable Pooling,NeurIPS 2018

具体细节,可以参考另一篇博文:[DIFFPOOL 图分类] - Hierarchical Graph Representation Learning with Differentiable Pooling NeurIPS 2018](https://blog.csdn.net/yyl424525/article/details/103307795)

(2)Graph u-net,ICML 2019
gPool实现了与DiffPool相当的性能。gPool需要 O ( ∣ V ∣ + ∣ E ∣ ) O(|V| + |E|) O(V+E)的空间复杂度,而DiffPool需要 O ( k ∣ V ∣ 2 ) O(k|V|^2) O(kV2),其中 V , E , k V,E,k V,E,k分别表示顶点数、边数和池化比率。gPool使用一个可学习的向量 p p p来计算投影分数,然后使用这些分数来从图中选择排名最高的节点保留下来

文中是基于Graph U-Nets的,因此必须先了解,可以参考另一篇博文:Graph U-Nets

为了进一步改进图池化方法,文中提出了SAGPool,它可以使用图的特征和拓扑结构信息来产生具有合理的时间和空间复杂度的层次表示。

3 方法:Self-Attention Graph Pooling(SAGPool)

SAGPool的关键在于它使用GNN来提供self-attention分数。

3.1 基于self-attention的图池化方法:SAGPool

Self-attention mask

使用注意力机制可以关注更重要的特征。self-attention,通常被称为intra-attention,关注的特征是注意力本身。SAGPool利用图卷积的方法得到self-attention分数。例如,使用Kipf的图卷积公式,则self-attention分数 Z ∈ R N × 1 Z \in \mathbb{R}^{N \times 1} ZRN×1根据如下计算:

Z = σ ( D ~ − 1 2 A ~ D ~ − 1 2 X Θ a t t ) (3) \tag{3} Z=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} X \Theta_{a t t}\right) Z=σ(D~21A~D~21XΘatt)(3)

  • Θ a t t ∈ R F × 1 \Theta_{a t t} \in \mathbb{R}^{F \times 1} ΘattRF×1是SAGPool层中的唯一参数
  • 上述公式和GCN公式不同的地方就在于参数矩阵 Θ a t t ∈ R F × 1 \Theta_{a t t} \in \mathbb{R}^{F \times 1} ΘattRF×1和输出矩阵 Z ∈ R N × 1 Z \in \mathbb{R}^{N \times 1} ZRN×1每行都是一列的,而GCN中每行是大于1列的

因为利用公式里融合了A和X的图卷积得到self-attention分数,所以这种池化的结果是基于图的特征和拓扑的

SAGPool采用了gPool中的节点选择方法,保留了输入图的一部分节点
i d x = top-rank ⁡ ( Z , ⌈ k N ⌉ ) , Z m a s k = Z i d x (4) \tag{4} \mathrm{idx}=\operatorname{top-rank}(Z,\lceil k N\rceil), \quad Z_{m a s k}=Z_{\mathrm{idx}} idx=top-rank(Z,kN),Zmask=Zidx(4)

  • 池化比率 k ∈ ( 0 , 1 ] k \in(0,1] k(0,1]是一个超参数,它决定要保留的节点数
  • top ⌈ k N ⌉ \lceil k N\rceil kN的节点是根据 Z Z Z的值来选择的
  • top-rank ⁡ \operatorname{top-rank} top-rank是返回top ⌈ k N ⌉ \lceil k N\rceil kN的节点的索引的函数
  • . i d x ._{idx} .idx是一个索引操作

图池化

池化部分就是根据idx对特征和结构进行topK的选择了:
X ′ = X i d x , : , X o u t = X ′ ⊙ Z m a s k , A o u t = A i d x , i d x (5) \tag{5} X^{\prime}=X_{\mathrm{idx}, \mathrm{:}} \quad, \quad X_{o u t}=X^{\prime} \odot Z_{mask} \quad , \quad A_{o u t}=A_{\mathrm{idx}, \mathrm{idx}} X=Xidx,:,Xout=XZmask,Aout=Aidx,idx(5)

  • X i d x , : X_{\mathrm{idx}, \mathrm{:}} Xidx,:是索引按行排列的特征矩阵 (每一行代表一个节点的特征向量)
  • X o u t X_{o u t} Xout是新的特征矩阵
  • A o u t A_{o u t} Aout是新的邻接矩阵
  • A i d x , i d x A_{\mathrm{idx}, \mathrm{idx}} Aidx,idx是按行和按列索引的邻接矩阵
  • 上图中的Masking操作,就是把根据top-K的id选出来的特征矩阵 X ′ X' X Z m a s k Z_{mask} Zmask进行Hardmard积(同size的矩阵对应元素相乘)

SAGPool的变种 - 使用不同的GNN

SAGPool中使用图卷积的主要原因是为了获得拓扑结构和节点特征。可以使用不同的GNN代替GCN(其他的如GAT、GraphSAGE等),所以计算计算注意力分数 Z ∈ R N × 1 Z \in \mathbb{R}^{N \times 1} ZRN×1的公式可以泛化为:

Z = σ ( GNN ⁡ ( X , A ) ) (6) \tag{6} Z=\sigma(\operatorname{GNN}(X, A)) Z=σ(GNN(X,A))(6)

计算注意力分数,不仅可以使用相邻节点,也可以使用多跳连接的节点。可以使用添加改变邻接矩阵形式扩展边,堆叠多层GNN层,使用多个注意力分数的平均值等方法来达到这个目的。

以一个连接两跳的节点为例。
(1)添加邻接矩阵的平方:  SAGPool augmentation \text { SAGPool}_{\text {augmentation}}  SAGPoolaugmentation
式(7)使用了两跳连接,该连接涉及边的扩展,允许两跳节点的间接聚合。添加邻接矩阵的平方相当于在两跳邻居之间创建了边:

Z = σ ( GNN ⁡ ( X , A + A 2 ) ) (7) \tag{7} Z=\sigma\left(\operatorname{GNN}\left(X, A+A^{2}\right)\right) Z=σ(GNN(X,A+A2))(7)

(2)叠加两层GNN层:  SAGPool serial \text { SAGPool}_{\text {serial}}  SAGPoolserial
式(8)使用了两跳连接,该连接涉及GNN层的堆叠,允许两跳节点的间接聚合。在这种情况下,SAGPool层的非线性和参数数量将增加:

Z = σ ( GNN ⁡ 2 ( σ ( GNN ⁡ 1 ( X , A ) ) , A ) ) (8) \tag{8} Z=\sigma\left(\operatorname{GNN}_{2}\left(\sigma\left(\operatorname{GNN}_{1}(X, A)\right), A\right)\right) Z=σ(GNN2(σ(GNN1(X,A)),A))(8)
公式(7)和公式(8)可以应用到更多跳的连接上。

(3)取多重注意力分数的平均值,类似于Multi-head GAT:  SAGPool parallel \text { SAGPool}_{\text {parallel}}  SAGPoolparallel
M M M个GNNs平均注意力分值:

Z = 1 M ∑ m σ ( G N N m ( X , A ) ) (9) \tag{9} Z=\frac{1}{M} \sum_{m} \sigma\left(\mathrm{GNN}_{m}(X, A)\right) Z=M1mσ(GNNm(X,A))(9)

文中分别将公式(7),(8),(9)中的模型称为  SAGPool augmentation \text { SAGPool}_{\text {augmentation}}  SAGPoolaugmentation  SAGPool serial \text { SAGPool}_{\text {serial}}  SAGPoolserial  SAGPool parallel \text { SAGPool}_{\text {parallel}}  SAGPoolparallel

看看SAGPool的源代码:
$$

torch_geometric据说是GNN的神器,这里面甚至封装了常用的GNN模型

from torch_geometric.nn import GCNConv
from torch_geometric.nn.pool.topk_pool import topk,filter_adj
from torch.nn import Parameter
import torch

class SAGPool(torch.nn.Module):
def init(self,in_channels,ratio=0.8,Conv=GCNConv,non_linearity=torch.tanh):
super(SAGPool,self).init()
self.in_channels = in_channels
self.ratio = ratio # 论文中的参数k
self.score_layer = Conv(in_channels,1) # 论文中的Z
self.non_linearity = non_linearity
def forward(self, x, edge_index, edge_attr=None, batch=None):
if batch is None:
batch = edge_index.new_zeros(x.size(0))
#x = x.unsqueeze(-1) if x.dim() == 1 else x
score = self.score_layer(x,edge_index).squeeze()
perm = topk(score, self.ratio, batch) # topk选择最大的几个
x = x[perm] * self.non_linearity(score[perm]).view(-1, 1) # mask
batch = batch[perm]
edge_index, edge_attr = filter_adj( # 选择子图结构特征
edge_index, edge_attr, perm, num_nodes=score.size(0))
return x, edge_index, edge_attr, batch, perm
$$

3.2 模型架构

卷积层

上文提到了GNN可以有很多种,如GAT、GraphSAGE等,本文还是用了Kipf的卷积:
h ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 h ( l ) Θ ) (10) \tag{10} h^{(l+1)}=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} h^{(l)} \Theta\right) h(l+1)=σ(D~21A~D~21h(l)Θ)(10)

  • Θ ∈ R F × F ′ \Theta \in \mathbb{R}^{F \times F^{\prime}} ΘRF×F
  • F F F F ′ F^{\prime} F分别表示第 l + 1 l+1 l+1层的输入特征维度和输出特征维度
  • 激活函数使用ReLU

readout层

受JK-net架构(Representation learning on graphs with jumping knowledge networks,2018;Towards sparse hierarchical graph classifiers,2018)的启发,提出了一种readout层,该层聚合节点特征以形成固定大小的表示。readout层的输出特征如下:

s = 1 N ∑ i = 1 N x i ∥ max ⁡ i = 1 N x i (11) \tag{11} s=\frac{1}{N} \sum_{i=1}^{N} x_{i} \| \max _{i=1}^{N} x_{i} s=N1i=1Nxii=1maxNxi(11)

  • N N N表示节点数量, x i x_i xi表示第 i i i个节点的特征
  • ∣ ∣ || ,表示concatenation,即特征串联操作,这个式子对每个节点的embedding,做了两种池化——平均池化和最大池化,然后拼接起来

代码为:
f r o m t o r c h g e o m e t r i c . n n i m p o r t g l o b a l m e a n p o o l a s g a p , g l o b a l m a x p o o l a s g m p x 1 = t o r c h . c a t ( [ g m p ( x , b a t c h ) , g a p ( x , b a t c h ) ] , d i m = 1 ) from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) fromtorchgeometric.nnimportglobalmeanpoolasgap,globalmaxpoolasgmpx1=torch.cat([gmp(x,batch),gap(x,batch)],dim=1)

全局池化架构

实现了(An end-to-end deep learning architecture for graph classification,AAAI 2018)中提出的全局池化架构。

分层池化架构

实现了(Towards sparse hierarchical graph classifiers,2018)分层池化架构。

对于具体的模型架构,本文使用了两种Global pooling architecture和Hierarchical pooling architecture:

  • 图2左图是全局池化结构,右图是分层池化结构,在后面的实验中,分别用下标g和h标示使用的是两种不同的池化结构
  • 全局池化结构由三个图卷积层组成,每层的输出被连接起来。节点特征在池化层之后的readout层中聚合。然后将图的特征表示传递到线性层进行分类。
  • 分层池化架构由三个block组成,每个block由一个图卷积层和一个图池化层组成。每个block的输出汇总在readout层中。将每个readout层输出的总和输入到线性层进行分类。

4 实验

在图分类任务中,评估了全局池化和分层池化方法。

数据集

选取了5个图的数量较大的数据集( > 1 k > 1k >1k):

  • D&D
    包含蛋白质结构的图。一个节点表示一个氨基酸,如果两个节点之间的距离小于 6 A ^ 6 \hat{\mathrm{A}} 6A^,则构造一条边。标签表示蛋白质是酶还是非酶。
  • PROTEINS
  • 也是一组蛋白质,其中节点是二级结构元素。
  • NCI1
    是一个用于抗癌活性分类的生物数据集。在数据集中,每个图表示一个化合物,节点和边分别表示原子和化学键。
  • NCI109
  • FRANKENSTEIN
    是一组具有包含连续值的节点特征的分子图。标签表示一个分子是诱变剂还是非诱变剂(凡是能引起生物体遗传物质发生突然或根本的改变,使其基因突变或染色体畸变达到自然水平以上的物质,统称为诱变剂)。

GNNs的评估

  • 所有模型均采用相同的early stopping准则和超参数选择策略,以保证比较的公平性。
  • 使用NVIDIA TitanXp GPU
  • 使用几何深度学习扩展库PyG实现所有的baselines和SAGPool

训练过程

  • (Pitfalls of graph neural network evaluation,2018)证明了不同的数据分割会影响GNN模型的性能。
  • 在实验中,使用10-fold交叉验证评估了超过20个随机种子的池化方法。
  • 总共使用了200个测试结果来获取每个数据集上每个方法的最终精度。
  • 使用训练集的10%的数据进行验证。
  • 使用了Adam优化器、early stopping准则、patience以及全局池化结构和分层池化结构的超参数选择策略。如果在epoch终止条件下最多100k个epoch验证损失没有改善,将停止训练。
  • 通过网格搜索得到最优超参数。网格搜索的范围如表2所示。

Baselines

  • Set2Set
    Set2Set需要额外的超参数,即LSTM模块的处理step数。LSTM模块为节点顺序不变的图生成embedding,所以可以假设readout层是不必要的。
  • SortPool
    SortPool是一种全局池化方法,它对节点进行排序来进行池化。节点数量设置为 K K K,使得图的60%有多于 K K K个的节点。在全局池化设置中,  SAGPool  g \text { SAGPool }_{g}  SAGPool g与SortPool具有相同的 K K K个输出节点。
  • DiffPool
    DiffPool是第一种端到端可训练的图池化方法,它可以生成图的分层表示。使用中没有对DiffPool使用batch normalization,因为这与池化方法无关。对于超参数搜索,池化比率从0.25到0.5不等。在引用的实现中,cluster大小设置为节点的最大数目的25%。当池化比率大于0.5时,DiffPoolh会导致内存不足。
  • gPool
    gPool为池化选择排名靠前的节点,与SAGPool方法类似。通过与gPool的比较表明,考虑拓扑结构有助于提高图形分类任务的性能。

5 分析

全局池化和分层池化

很难确定全局池化结构或层次池化结构是否完全有利于图形分类。因为全局池化结构 P O O L g P O O L_{g} POOLg  SAGPool  g \text { SAGPool }_{g}  SAGPool g  SortPool  g \text { SortPool }_{g}  SortPool g  Set2Set  g \text { Set2Set }_{g}  Set2Set g)使信息丢失最小化,因此它在节点较少的数据集(NCI1、NCI109、FRANKENSTEIN)上的性能优于分层池化结构 P O O L h P O O L_{h} POOLh  SAGPool  h \text { SAGPool }_{h}  SAGPool h  gPool  h \text { gPool }_{h}  gPool h  DiffPool  h \text { DiffPool }_{h}  DiffPool h

但是,分层池化 P O O L h P O O L_{h} POOLh对节点数较多的数据集(D&D和PROTEINS)更有效,因为它能有效地从大规模图中提取有用的信息。因此,使用最适合给定数据的池化结构非常重要。尽管如此,SAGPool在每种架构中通常都表现良好。

考虑图拓扑结构的影响

和gPool不一样, SAGPool使用一阶近似图的拉普拉斯算子 D ~ − 1 2 A ~ D ~ − 1 2 \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} D~21A~D~21,这使得SAGPool考虑了图的拓扑结构。如表3所示,考虑图的拓扑结构可以提高性能。此外,图的Laplacian算子不需要重新计算,因为它在前一个图卷积层中也使用了,可以预先计算。

虽然SAGPool具有与gPool相同的参数,但它在图分类任务中表现出了更优异的性能。

结果表明,SAGPool在总体上表现良好,在D&D和PROTEINS方面表现尤为突出。在实验中,SAGPool在所有的数据集上都优于分层池化的方法。

稀疏实现

使用稀疏矩阵操作图数据对于GNNs来说非常重要,因为邻接矩阵通常是稀疏的。
用稠密矩阵计算图卷积时,乘法 A X AX AX的计算复杂度为 O ( ∣ V ∣ 2 ) O(|V|^2) O(V2),其中 A A A为邻接矩阵, X X X为节点特征矩阵, V V V为顶点。如(Towards sparse hierarchical graph classifiers,2018)所述,密集矩阵池化会导致内存效率问题。

如果在同一操作中使用稀疏矩阵,则计算复杂度降低到 O ( ∣ E ∣ ) O(|E|) O(E),其中 E E E表示边。由于SAGPool是一种稀疏池化方法,使用稀疏实现可以降低计算复杂度,而DiffPool是一种密集池化方法,计算复杂度较高

稀疏性也影响空间复杂性。因为SAGPool使用GNN来获取注意力分数,所以SAGPool需要 O ( ∣ V ∣ + ∣ E ∣ ) O(|V |+|E|) O(V+E)的稀疏池化存储空间,而稠密池化方法需要 O ( ∣ V ∣ 2 ) O(|V|^2) O(V2)

节点数量的关系

在DiffPool中,由于GNN产生了assignment矩阵S,因此在构建模型时必须定义cluster的大小。根据参考方法的实现,cluster的大小必须与最大节点数成比例。DiffPool的这些要求会导致两个问题。

  • 参数的数量取决于最大节点数,如图3所示。
  • 当节点数量变化很大时,很难确定正确的cluster大小。例如,在1178个图中只有10个图具有超过1000个节点,其中最大节点数为5748,最小节点数为30。如果池化比率为10%,则cluster大小为574,这将扩展大多数数据池化后的图大小。

在SAGPool中,参数的数量与cluster的大小无关。此外,可以根据输入节点的数量更改cluster大小。

  • 图3:图中参数的数量随着节点的增多而增大。
  • x轴标号为输入图节点数
  • y轴为分层池化模型参数:输入节点特征数为128,隐含层特征大小为128,class数为2。
  • SAGPool使用公式(3)的图卷积。
  • k k k表示池化比率, k = 1.0 k = 1.0 k=1.0表示池化后保留整个节点。
  • 无论输入图的大小和池化比率如何,gPool和SAGPool参数数量都一致。

SAGPool变种比较

为了研究SAGPool方法的潜力,在两个数据集上评估了SAGPool的变种。可以用以下操作修改SAGPool:

  • 更改GNN的类型
  • 考虑两跳连接。实验中使用两个连续GNN层(  SAGPool  s e r i a l \text { SAGPool }_{serial}  SAGPool serial)和添加邻接矩阵的方式来实现2跳连接(  SAGPool  a u g m e n t a t i o n \text { SAGPool }_{augmentation}  SAGPool augmentation)。
  • 对多个GNN的注意力得分求平均值
  • 使用的数据集和SAGPool中的GNN类型不同,图分类的性能有所不同
  • 两跳邻居的信息有助于提高性能
  • 发现为数据集选择合适的 M M M值有助于实现稳定的性能

当前方法的局限性

  • 保留一定比例(池化比率 k k k)的节点来处理不同大小的不同输入图,这在之前的研究中也做过
  • 在SAGPool中,无法为每个图将池化比率参数化来找到最优值。为了解决这个问题,文中使用二分类来决定保留哪些节点,但是这并没有完全解决问题。

可能的扩展工作

  • 为每个图使用可学习的池化比率来获得最优的cluster大小
  • 研究每个池化层中多attention mask的影响

猜你喜欢

转载自blog.csdn.net/yyl424525/article/details/112340961