【论文下饭】Temporal Graph Network for Deep Learning on Dynamic Graphs


综述
Representation Learning for Dynamic Graphs: A Survey

知识点
Transductive\inductive

transductive、inductive理解
Transductive和Inductive


Paper: Temporal graph network for deep learning on dynamic graphs
Cite: Rossi E, Chamberlain B, Frasca F, et al. Temporal graph networks for deep learning on dynamic graphs[J]. arXiv preprint arXiv:2006.10637, 2020.

中文参考:
内容比较好:TGN:Temporal Graph Networks for Deep Learning on Dynamic Graphs
格式比较好:TGN: TEMPORAL GRAPH NETWORKS FOR DEEP LEARNING ON DYNAMIC GRAPHS论文笔记


1 介绍

图表示学习 取得了一系列的成功。
图 普遍地用于 关系 和 交互 系统的建模,比如说 社交网络、生物网络。
在这些网络上,普遍使用GNN。GNN通过 信息传递机制 聚合邻居的信息,得到该节点的嵌入向量。之后,便可用于节点分类、图分类、边集预测任务上。

大部分在 图上的深度学习算法 都有一个前提假定——图是静态的。
但是,大多数现实生活的交互系统,比如说 社交网络、生物网络,图都是动态的。
通过忽略动态图的时序特征,使用 静态图深度学习方法,也是可以的。但是,这一般是次优解。因为,在某些情况下,模型会忽略掉一些动态图的关键特征

在动态图上的研究也是今年才兴起的,大多数研究都局限于 离散时间图(discrete-time dynamic graph)。当动态图是连续的(边可以在任何时间出现)、进化的(点可以连续地加入到图钟)时,上面提到的方法都不适合。

直到最近,有很多方法提出 说 支持 连续时间图(continuous-time dynamic graph)。

本文的贡献

  • 提出了适用于 连续时间图 的 generic inductive framework Temproal Graph Networks(TGNs)。在本文之前的很多方法,都可以看作是TGNs的一个特例。
  • 提出了一个高效的训练策略,使得模型能够从时序数据中实现高效的并行处理。
  • 做了很多详细的 消融实验,分析了 本文模型各个组件的性能。
  • 本文的模型在许多(both transductive and inductive)任务上取得了SOTA表现,并且 速度比之前的方法快。

2 背景

在静态图上的深度学习
主要讲的就是GNN(这里跳过)

动态图
动态图有两种分类。

  • Discrete-time dynamic graphs 离散时间动态图DTDG
    它是一系列的 一段时间内的 动态图的快照(snapshots)。

  • Continuos-time dynamic graphs 连续时间动态图CTDG
    它是动态图更加一般性的表示。由一系列的时间组成(timed lists of events)。这些事件包括 边的增加/删除、节点的添加/删除、节点/边的特征变化

本文在正文部分 使用节点/边的添加 作为例子。(节点/边删除 在附录讨论)

3 Temporal Graph Network

TGN的介绍。encoder-decoder模型。

3.1 核心模块

在这里插入图片描述

Memory

对于每个节点 i i i,在时刻 t t t,都有一个向量 s i ( t ) s_i(t) si(t)表示为该节点的记忆单元。它记录了节点 i i i [ 0 , t ] [0,t] [0,t]时间内记忆。它代表了节点的“历史”。

有了记忆模块,TGNs就可以记录 每个节点在图中 长期的依赖关系。

  • 当新的节点被加入时,它的记忆单元初始化一个零向量。
  • 每个事件过后,相关节点的记忆单元就会被更新。

也可以使用全局的记忆单元,来记录整个图的变化,但是为了简单起见,这留作未来的工作。

Message Function(msg)

对于一个 交互事件 e i j ( t ) e_{ij}(t) eij(t),有两个方向的信息。
m i ( t ) = m s g s ( s i ( t − ) , s j ( t − ) , Δ t , e i j ( t ) ) m j ( t ) = m s g d ( s j ( t − ) , s i ( t − ) , Δ t , e i j ( t ) ) \begin{aligned} &m_i(t) = \mathrm{msg_s} (s_i(t^-),s_j(t^-),\Delta t,e_{ij}(t) ) \\ &m_j(t) = \mathrm{msg_d} (s_j(t^-),s_i(t^-),\Delta t,e_{ij}(t)) \end{aligned} mi(t)=msgs(si(t),sj(t),Δt,eij(t))mj(t)=msgd(sj(t),si(t),Δt,eij(t))
其中, s i ( t − ) s_i(t^-) si(t)为节点 i i i t t t时刻前的记忆单元。 m s g msg msg为可以学习的信息传递函数,比如说MLP。

在本文中,使用 i d e n t i t y ( i d ) identity(id) identity(id)作为信息传递函数 m s g msg msg。(即,简单地对 m s g msg msg的输入进行concate)

Message Aggregator(agg)

在使用批次(batch)处理时,有一些节点的 m s g msg msg会被多次使用。
出于性能考虑,本文提出了 信息聚合机制
m ˉ i ( t ) = a g g ( m i ( t 1 ) , . . . , m i ( t b ) ) \bar m_i(t) = \mathrm{agg}(m_i(t_1),...,m_i(t_b)) mˉi(t)=agg(mi(t1),...,mi(tb))
其中, t 1 , . . . , t b t_1,...,t_b t1,...,tb为节点 i i i 在相同批次中的时间序列。 a g g \mathrm{agg} agg 是聚合函数,比如说RNN或attention机制。

简单地说,就是把 同一批次中所有节点 i i i m s g msg msg聚合到一起。

在本文中,使用most recent message只保留最近的信息。mean message计算所有信息的平均值。

Memory Updater(mem)

对于每个事件涉及到的节点,需要更新其记忆单元:
s i ( t ) = m e m ( m ˉ i ( t ) , s i ( t − ) ) s_i(t) = \mathrm{mem}(\bar m_i(t),s_i(t^-)) si(t)=mem(mˉi(t),si(t))
其中, m e m \mathrm{mem} mem可以学习的更新函数。比如说,循环神经网络LSTM、GRU。

在本文中,使用 G R U \mathrm{GRU} GRU作为记忆更新函数 m e m \mathrm{mem} mem

Embedding(emb)

向量嵌入模块 可以生成 每个节点 i i i t t t时刻的 时序嵌入向量 z i ( t ) z_i(t) zi(t)

记忆过期问题
节点 i i i的记忆单元更新,当且仅当 存在事件 包含节点 i i i。当某个节点 i i i的记忆单元长时间得不到更新时,节点 i i i的记忆单元就可以被认为 过期(stale) 了。

比如,在社交网络中,某个用户 长时间 不使用该平台 后,又再次使用。

嵌入向量计算的统一形式如下:
z i ( t ) = e m b ( i , t ) = ∑ j ∈ n i k ( [ 0 , t ] ) h ( s i ( t ) , s j ( t ) , e i j , v i ( t ) , v j ( t ) ) z_i(t) = \mathrm{emb}(i,t) = \sum_{j \in n^k_i([0,t])}h(s_i(t),s_j(t),e_{ij},\bold{v}_i(t),\bold{v}_j(t)) zi(t)=emb(i,t)=jnik([0,t])h(si(t),sj(t),eij,vi(t),vj(t))
其中, h h h是可以学习的函数。它可以有多种实现方式,比如说:
Identity(id)
e m b ( i , t ) = s i ( t ) \mathrm{emb}(i,t) = s_i(t) emb(i,t)=si(t),直接使用节点的记忆单元。

Time projection(time)
e m b ( i , t ) = ( 1 + Δ t   w ) ∘ s i ( t ) \mathrm{emb}(i,t) = (1+\Delta t ~\bold{w}) \circ s_i(t) emb(i,t)=(1+Δt w)si(t),其中, w \bold{w} w 是可以学习的参数, Δ t \Delta t Δt 是距上一次交互的时间间隔, ∘ \circ element-wise向量乘积。(该方法使用于Joide模型中(Kumar etal., 2019))。

Temporal Graph Attention(attn)

Firstly proposed in TGAT(Xu et al., 2020)

L L L 层的图注意力机制,可以利用节点 i i i L L L跳的时序邻居信息,计算(节点 i i i)的嵌入向量。

节点 i i i,在 t t t时刻,第 l l l层的输入是 h i ( l − 1 ) ( t ) h_i^{(l-1)}(t) hi(l1)(t),节点 i i i的邻居表示 { h 1 l − 1 ( t ) , . . . , h N l − 1 ( t ) } \{h_1^{l-1}(t),...,h_N^{l-1}(t) \} { h1l1(t),...,hNl1(t)},特征为 e i 1 ( t 1 ) , . . . , e i N ( t N ) e_{i1}(t_1),...,e_{iN}(t_N) ei1(t1),...,eiN(tN)

注意:因为训练是按批次(batch)进行的,特征 e e e的发生时刻可能不同。

h i ( l ) ( t ) = M L P ( l ) ( h i ( l − 1 ) ( t ) ∥ h ~ i ( l ) ( t ) ) , h ~ i ( l ) ( t ) = M u l t i H e a d A t t e n t i o n ( l ) ( q ( l ) ( t ) , K ( l ) ( t ) , V ( l ) ( t ) ) , q ( l ) ( t ) = h i ( l − 1 ) ( t ) ∥ ϕ ( 0 ) , K ( l ) ( t ) = V ( l ) ( t ) ) = C ( l ) ( t ) , C ( l ) ( t ) = [ h 1 ( l − 1 ) ( t ) ∥ e i 1 ( t 1 ) ∥ ϕ ( t − t 1 ) , . . . , h N ( l − 1 ) ( t ) ∥ e i N ( t N ) ∥ ϕ ( t − t N ) ] \begin{aligned} &\bold{h}_i^{(l)}(t) = \mathrm{MLP}^{(l)}(\bold{h}_i^{(l-1)}(t) \parallel \tilde{h}_i^{(l)}(t)), \\ &\tilde{\bold{h}}_i^{(l)}(t) = \mathrm{MultiHeadAttention}^{(l)}(\bold{q}^{(l)}(t),\bold{K}^{(l)}(t),\bold{V}^{(l)}(t)),\\ &\bold{q}^{(l)}(t) = \bold{h}_i^{(l-1)}(t) \parallel \phi(0),\\ &\bold{K}^{(l)}(t) = \bold{V}^{(l)}(t)) = \bold{C}^{(l)}(t), \\ &\bold{C}^{(l)}(t) = [\bold{h}_1^{(l-1)}(t) \parallel \bold{e}_{i1}(t_1) \parallel \phi(t-t_1),...,\bold{h}_N^{(l-1)}(t) \parallel \bold{e}_{iN}(t_N) \parallel \phi(t-t_N)] \end{aligned} hi(l)(t)=MLP(l)(hi(l1)(t)h~i(l)(t)),h~i(l)(t)=MultiHeadAttention(l)(q(l)(t),K(l)(t),V(l)(t)),q(l)(t)=hi(l1)(t)ϕ(0),K(l)(t)=V(l)(t))=C(l)(t),C(l)(t)=[h1(l1)(t)ei1(t1)ϕ(tt1),...,hN(l1)(t)eiN(tN)ϕ(ttN)]
其中, ϕ \phi ϕ 是一个 通用的时序编码器(generic time encoding), ∥ \parallel 是concate操作,最后得到的嵌入向量为 z i ( t ) = e m b ( i , t ) = h i ( L ) ( t ) \bold{z}_i(t) = \mathrm{emb}(i,t) = \bold{h}_i^{(L)}(t) zi(t)=emb(i,t)=hi(L)(t) q ( l ) ( t ) \bold{q}^{(l)}(t) q(l)(t)是可以是节点 i i i或节点 i i i L − 1 L-1 L1跳邻居。 K ( l ) ( t ) \bold{K}^{(l)}(t) K(l)(t) V ( l ) ( t ) \bold{V}^{(l)}(t) V(l)(t)是节点 i i i的邻居。

简单地说,就是一个多头注意力机制,重点在 C \bold{C} C中,把时序特征一并输入。

特别地,与TGAT中提到的不同的是,在第 0 0 0层时,本文考虑了节点本身的特征 v ( t ) \bold{v}(t) v(t)node-wise temporal features),即 h j ( 0 ) ( t ) = s j ( t ) + v j ( t ) h_j^{(0)}(t) = s_j(t)+\bold{v}_j(t) hj(0)(t)=sj(t)+vj(t)。这使得模型可以同时利用 现有的记忆 s j ( t ) s_j(t) sj(t)和时序节点特征 v j ( t ) \bold{v}_j(t) vj(t)

Temporal Graph Sum(sum)
简单、快速的聚合方法。
h i ( l ) ( t ) = W 2 ( l ) ( h i ( l − 1 ) ( t ) ∥ h ~ i ( l ) ( t ) ) , h ~ i ( l ) ( t ) = R e L u ( ∑ j ∈ n i ( [ 0 , t ] ) W 1 ( l ) ( h j ( l − 1 ) ( t ) ∥ e i j ∥ ϕ ( t − t j ) ) ) \begin{aligned} &\bold{h}_i^{(l)}(t) = \bold{W}_2^{(l)}(\bold{h}_i^{(l-1)}(t) \parallel \tilde{\bold{h}}_i^{(l)}(t)),\\ &\tilde{\bold{h}}_i^{(l)}(t) = \mathrm{ReLu}(\sum_{j \in n_i([0,t])} \bold{W}_1^{(l)} (\bold{h}_j^{(l-1)}(t) \parallel \bold{e}_{ij} \parallel \phi(t-t_j))) \end{aligned} hi(l)(t)=W2(l)(hi(l1)(t)h~i(l)(t)),h~i(l)(t)=ReLu(jni([0,t])W1(l)(hj(l1)(t)eijϕ(ttj)))

同样地, ϕ \phi ϕ 是一个 通用的时序编码器(generic time encoding),最后得到的嵌入向量为 z i ( t ) = e m b ( i , t ) = h i ( L ) ( t ) \bold{z}_i(t) = \mathrm{emb}(i,t) = \bold{h}_i^{(L)}(t) zi(t)=emb(i,t)=hi(L)(t)


图向量嵌入模块 通过聚合邻居的记忆信息,缓和了 (记忆)过期问题,使得TGN可以计算最新的嵌入向量信息。

temporal graph attention使得模型能够 寻找 包含重要特征和时序信息 的邻居节点。

3.2 训练

TGN可以用于许多任务上,比如说 边集预测(自监督)或者 节点分类(半监督)。

我们使用 连接预测作为例子:提供一系列时间排序的交互,我们的目标是从过去的观察中,预测未来可能出现的交互。

交互(interactions):也就是 边。


在这里插入图片描述

之前提到的训练策略在 记忆相关模块中 存在问题——不能直接影响loss,也就是说接收不到梯度

为了解决这个问题,记忆单元(memory)必须在预测交互之前 更新。但是,这就造成了信息泄露(information leakage)。

为了避免这个问题,本文提出了一个额外的模块Raw Message Store,用于存储 batch b b b 的交互信息,原来的Message模块 用于存储 batch b − 1 b-1 b1 的交互信息。
这样,通过添加 缓存,解决了信息泄露的问题。

需要注意的是,我们的 batch size不能选的太大。经过 速度和颗粒度(granularity)的权衡,作者认为batch size=200is good。

因为,当前的 预测 用到的是 上一个batch的交互信息。如果batch太大(极端地说,整个数据集),所有的预测 用到的都是 初始的零向量记忆单元。

4 相关工作

早期的工作 都集中于DTDGs上。
比如说,

  • 聚合 图的快照 然后使用静态的方法
  • 把 图的快照 整合成张量并分解它
  • 编码 每张快照 产生一系列嵌入向量。

另一条编码DTDGs的主线工作:先在初始的快照上使用 随机游走,然后对于子序列快照修改游走行为。

时空图(spatio-temporal graphs)是动态图的特例。因为 时空图的 拓扑逻辑是固定的。


CTDGs。

  • 根据 连续时间(continuos time) 限定 随机游走 的转移概率。
  • CTDGs的序列模型。对每个事件 e i j e_{ij} eij使用RNN来更新来更新 源点 和 宿点的表示(representations)。

许多的架构用的都是基于RNN的node-wise记忆单元。由于缺少GNN的信息聚合机制,使得这些记忆单元可能出现 (记忆)过期问题,同时 它的计算也是十分耗时的。


最新的CTDGs学习模型,都可以看作本文框架TGN的一个特例。

在这里插入图片描述

5 实验

数据集
Wikipedia
Reddit
Twitter

任务:edge prediction(预测两个节点未来出现连接的概率)。同时研究了transductiveinductive设置下的情况。

transductive任务中,预测的连接 在训练时 出现过;在inductive任务中,预测的连接在 训练中 没有出现过。

本文使用的解码器是一个简单的MLP。

Baselines
strong baselines:
CTDNE
Jodie
DyRep
TGAT
GAE
VGAE
DeepWalk
Node2Vec
GAT
GraphSAGE

5.1 性能表现(实验结果)

在这里插入图片描述

在这里插入图片描述

总结:最好的模型TGN-atten很强,而且很快(比TGAT快30倍)。

5.2 模块选择

在这里插入图片描述

Memory

比较的模块:

TGN-no-mem 没有使用记忆模块
TGN-attn 最好的模型

现象:

  • TGN-attnTGN-no-mem慢3倍。
  • TGN-attnTGN-no-mem 准确率提升4%。

结论:

  • 记忆单元 能够帮助 存储节点的长期信息。
  • 采样更多的邻居信息 可以 达到同样(带记忆)的效果。(但花费更多的时间)

Embedding Moudle

比较的模块:

TGN-id(DyRep)
TGN-time(Jodie)
TGN-attn
TGN-sum

现象:

  • TGN-id 优于 TGN-time
  • graph-base的方法(TGN-attn, TGN-sum) 比 graph-less的方法 TGN-id 高出一大截。
  • TGN-attn 仅比 TGN-sum高一点点。

结论:

  • 使用图的最近的信息,选择哪些邻居是最关键的,是很重要的影响因素。

Message Aggregator

比较的模块:

TGN-mean
TGN-attn

现象:

  • TGN-meanTGN-attn好一点。
  • TGN-meanTGN-attn慢3倍。

Number of layers

比较的模块:

TGN-2l
TGN-attn

比较使用到的GNN层数,因为在TGAT中,两层比一层 好了不止10%。

现象:

  • TGN-2l 仅比 TGN-attn 高一点。

结论:

  • 由于使用到了记忆单元TGN-attn仅使用1层就可以达到不错的效果。
  • 当使用1-hop邻居的记忆单元时,我们 间接地 使用了 比1-hop更远的信息。

6 结论

  • TGN:通用的 时间连续图 深度学习框架。之前的一些工作,都可以看作是本文框架的一个特例。
  • TGN可以做到SOTA。
  • 本文对每个模块做了详细的消融实验。记忆模块(可以存储 长期信息),嵌入模块(生成最新的节点嵌入向量)很重要。

猜你喜欢

转载自blog.csdn.net/LittleSeedling/article/details/124463497