《Fast graph representation learning with PyTorch Geometric》阅读笔记【PyG-paper】

这篇文章是PyG的官方paper,里面提供了很多有价值的信息,在这里进行一下汇总和思考。

Fey M, Lenssen J E. Fast graph representation learning with PyTorch Geometric[J]. arXiv preprint arXiv:1903.02428, 2019.

PyG主要是利用消息传递框架的思想来「提供API接口」以及「实现高GPU吞吐量」(加速)的。

图数据结构

G = ( X , ( I , E ) ) , X ∈ R N × F , I ∈ N 2 × E , E ∈ R E × D G=(X,(I,E))\quad ,X\in R^{N\times F},I\in N^{2\times E},E\in R^{E\times D} G=(X,(I,E)),XRN×F,IN2×E,ERE×D
X就是我们常说的节点特征矩阵data.x,I是之前见过的COO格式的边索引表data.edge_index,E是边的特征矩阵(在节点分类任务中用不到)。当然,I也可以表示成SpareTensor的形式,这个后面也会说什么时候需要用到这种形式。

消息传递框架MPNN

x i ′ = γ ( x i , □ j ∈ N ( i )   ϕ ( x i , x j , e j , i ) ) , \mathbf{x}_i^{\prime} = \gamma \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right), xi=γ(xi,jN(i)ϕ(xi,xj,ej,i)),
消息传递其实就是一种邻域聚合操作。

符号 函数
ϕ \phi ϕ message(MLP)
□ \square aggregate(add、mean、max)
γ \gamma γ update(MLP)

在这里插入图片描述

需要重点说明的是聚合过程中使用的GS(gather and scatter)操作。在附录A中,重点将GS方法和SpMM(Sparse matrix multiplication)方法进行了对比。

GS方法

GS方法是在PyG中提出的用于模型加速的操作。

GS操作因为首先要将节点映射到边空间上,所以不可避免的会产生非对齐(non-coalesced)内存访问,这样就会导致不能使用连续的地址。但是,很高的边级并行性加上原子操作,让GS操作能够获得很高的数据吞吐量。并且通过实验也证明了,相比于SpMM,原子操作在节点平均度数>=128的时候,能够显著的节约运行的时间。

GS方法除了速度很快,它的灵活性也很好,可以很容易的加上self-loop、双向边、邻域聚合等等。

但是,GS方法也并非十全十美。在稠密的大图,特别是边很多的时候,GS方法会导致很高的内存占用,甚至会导致内存溢出。这个时候,SpMM就可以派上用场了,用时间换空间。

SpMM方法

SpMM方法其实就是稀疏矩阵乘法的操作。

一般来说,首先要先对邻接矩阵A进行预处理,得到CSR格式的存储格式,然后再进行稀疏矩阵的各项运算。并且,在backward阶段因为需要对齐(coalesced),所以也会花费大量的时间。综合看来,SpMM方法也就会比GS方法慢很多了。

虽然SpMM运行时间更长,但是在图比较大且稠密的时候,它的内从占用要比GS方法小得多,从而避免内存溢出的现象发生。

GNN层和模型

我们如果去查看官方文档,会发现里面已经把一些比较经典且重要的图卷积层GConv实现好了,包括但不限于ChebNet、GCN、SGC、GAT、GraphSAGE、GIN、ClusterGCN、GraphSAINT等等。此外还有一些模型和框架,比如说GAE、JK-Nets等等,都可以直接去使用。

对于图级的任务,还有很多Pooling的方法。官方文档的教程也比较详细,在此就不过多赘述。

有一个点需要去注意,就是src_node和target_node的问题。默认的flow是src_to_target。如图所示,j是src_node,代表的是邻居节点,i是target_node,代表的是中心节点,和平常的习惯好像不太一样,需要去注意一下。

j
i

图数据集

PyG提供很多已经封装好了的图数据集。下面是几个例子。

在这里插入图片描述

OGB数据集也可以使用from ogb.nodeproppred import PygNodePropPredDataset, Evaluator进行导入为PyG数据集的格式,具体操作可以取参考OGB官网的讲解。

此外,还可以通过重写process方法和使用transforms方法,自定义PyG数据集,用法也比较灵活。不过现阶段做实验基本上直接用PyG封装好的数据集或OGB就行了。

模型的训练与测试

结合毕设,近期主要在关注节点分类任务。

数据集的划分方面有固定划分和随机划分2种,一般来说固定划分效果会更好一点,引文网络数据集一般是默认每个类去20个节点,用平均正确率来评估模型。

关于训练时候的minibatch,PyG里面提供的Dataloader是针对多个图的minibatch(每个batch包含n个图)。也就是说,如果数据集本身就只有一个全联通的图,那么用不用PyG提供的minibatch都是一样的full-batch,我把它称为图级的minibatch。

与之区别的针对节点的minibatch,其实说白了也就是节点采样。在PyG中是使用的NeighborSampler来实现的,这个才是真正的节点级的minibatch。

在这里插入图片描述

官方文档与代码地址

pytorch_geometric:https://github.com/rusty1s/pytorch_geometric
pytorch_scatter:https://pytorch-scatter.readthedocs.io/en/latest/
模型代码实例:https://github.com/rusty1s/pytorch_geometric/tree/master/benchmark/citation
官方文档:https://pytorch-geometric.readthedocs.io/en/latest/

猜你喜欢

转载自blog.csdn.net/weixin_41650348/article/details/113751425