阅读笔记(FedGraphNN: A Federated Learning Benchmark System for Graph Neural Networks)

源码:https://github.com/FedML-AI/FedGraphNN

大多数图学习模型如图神经网络(Graph Neural Networks,GNNs)都是基于大量的图数据进行训练的,然而在许多现实场景中,例如医疗保健系统中的住院预测,图数据通常存储在多个数据所有者处,由于涉及患者的隐私和相关法律法规限制,不同数据所有者的数据不能直接共享。

联邦学习(Federated Learning, FL)是一种分布式学习方案,通过多个参与方(即客户端)协作训练一个模型而无需共享他们的隐私数据,以解决数据隔离问题。图联邦学习(Federated Graph Learning,FGL)是联邦学习在图数据上的应用,通过以联邦方式训练图神经网络,来解决图数据隐私保护的问题。

文章将图联邦学习根据处理对象的不同分为了以下三类:

(a)图级联邦学习:每个参与图联邦学习的客户端都持有一组图数据,例如在生物化学行业中,一个分子可以表示为一个图,其中节点表示原子,边表示化学键,客户端持有的图数据包括多个分子图,典型的应用是图分类。现实场景包括分子试验、蛋白质发现等。

(b)子图级联邦学习:每个客户端拥有整张图的一部分数据(即子图),典型应用为链接预测、节点分类。现实场景包括推荐系统、知识图谱补全等。

(c)节点级联邦学习:每个客户端持有包含一个或多个节点的自我中心网络,对于每个节点来说,只能看到其k-top邻居节点及边。其中典型的任务是节点分类和链路预测。现实场景包括社交网络、传感器网络等。

FedGraphNN:

文章给出了 FGL模型的范式表示。假设在FedGraphNN统一框架中包括一个服务器(FL Server)和K个客户端(FL Client)。每一个客户端拥有一份私密的数据集,并使用 GNN 模型进行本地的训练及预测。每一个客户端借助服务器彼此协作以改进其GNN模型,而无需直接共享数据集。其中,第k个客户端拥有数据集D^{(k)}:=(G^{(k)},Y^{(k)}),而G^{(k)}=(\nu ^{(k)},\varepsilon ^{(k)})表示由节点和边构成的图,并具有节点特征集X^{(k)}=\left \{ x_{m}^{k} \right \}_{m\in \nu^{(k)}}、边特征集Z^{(k)}=\left \{ e_{m,n}^{k} \right \}_{m,n \in \nu ^{(k)}}Y^{(k)}表示图的标签。

FedGraphNN框架下,每一个客户端内的 GNN 模型以消息传递网络(Message Passing Neural Network,MPNN)的形式给出,大多数的基于空域的 GNN 模型都可以用 MPNN 的形式表示。MPNN 模型包括两个阶段:消息传递阶段(message-passing phase)和读出阶段(readout phase),如下图所示。

Phase 1 消息传递阶段:

消息传递阶段包括两个步骤:(1)模型收集和转换邻居的消息;(2)模型使用聚合的消息来更新节点的隐藏状态。

对于客户端k的第l层(l=0,1,...,L-1),一个L层的 MPNN 表示为:

m_{i}^{(k,l+1)}=AGG(\left \{ M_{\theta }^{(k,l+1)}(h_{i}^{(k,l)},h_{j}^{(k,l)},z_{i,j})|j\in N_{i} \right \})

h_{i}^{(k,l+1)}=U_{\varnothing }^{(k,l+1)}(h_{i}^{(k,l)},m_{i}^{(k,l+1)})

其中,h_{i}^{(k,0)}为客户端kl层的节点特征,当l=0时有h_{i}^{(k,0)}=x_{i}^{(k)}AGG(\cdot )代表一种聚合函数,在GCN模型中,该聚合函数即简单的SUM操作。N_{i}表示节点i的邻居节点集,M_{\theta }^{(k,l+1)}是一种消息生成函数,以当前节点h_{i}、相邻节点h_{j}隐藏状态以及边特征z_{i,j}为输入。U_{\varnothing }^{(k,l+1)}为隐层状态更新函数。

 Phase 2 读出阶段:

上一阶段完成了消息传递后,学习到各个节点的隐藏状态,即MPNN的最后一层输出。接下来可根据下游任务的不同,执行不同的读出处理(类似于解码的操作),它可以写作:

\hat{y}_{S}^{(k)}=R_{\delta }(\left \{ h_{i}^{(k,L)}|i\in \nu _{S}^{(k)} \right \})

其中,S可以是单个节点(节点分类任务)、节点对(链接预测任务)或者节点集(图分类任务)。R_{\delta }可以是级联函数或者是池化函数(例如多层感知机MLP)。

FGL:

可从全局角度定义W=\left \{ M_{\theta } ,U_{\varnothing },R_{\delta }\right \}为客户端k的图神经网络模型的所有可训练权重参数,令其第T轮的本地参数为W^{k,T}。所有客户端将自身的本地参数上传到服务器,服务器使用聚合机制得到全局参数W^{(k,T+1)},然后下放回各个本地客户端。因此 FedGraphNN可以看作是一个分布式优化问题,其优化目标为调节W以使F(W)最小化:

 其中,f^{(k)}(W)=\frac{1}{N^{(k)}}\sum_{i=1}^{N^{(k)}}L(W;x_{i}^{(k)},z_{i}^{(k)},y_{i}^{(k)})代表客户端k本地目标函数衡量具有N^{(k)}个数据样本的图数据集D^{(k)}的局部经验风险。L代表全局GNN模型的损失函数。

为解决此优化问题,最直接的方案是FedAvg。在FedAvg中,服务器上的聚合函数只是平均了模型参数。另外,FedAvg是随机选取若干个客户端进行本地更新,但是是对所有客户端的模型进行聚合,而未被选中的客户端的模型参数有w_{t+1}^{(k)}=w_{t}^{(k)}

FedGraphNN 归纳地使用GNN(即模型独立于被训练的图的结构),因此,服务器在参数聚合过程中不需要任何客户端的图拓扑信息。也可以应用其他先进的算法,如 FedOPT、FedGKT 和 Decentralized FL。

支持的GNN模型:GCN、GAT、GraphSage、SGC、GIN。

支持的FL算法:FedAvg、FedOPT。

FedGraphNN架构图:

猜你喜欢

转载自blog.csdn.net/weixin_44458771/article/details/129039522