【Paper Introduction】- STFL: A Spatial-Temporal Federated Learning Framework for Graph Neural Networks

Paper information

STFL: A Spatial-Temporal Federated Learning Framework for Graph Neural Networks

Original address: STFL: A Spatial-Temporal Federated Learning Framework for Graph Neural Networks: https://arxiv.org/abs/2111.06750
Source code: https://github.com/JW9MsjwjnpdRLFw/TSFL

Summary

We present a spatial-temporal federated learning framework for graph neural networks, namely STFL. The framework explores the underlying correlation of the input spatial-temporal data and transforms it into both node features and an adjacency matrix. The federated learning setting in the framework ensures data privacy while achieving good model generalization. Experimental results on the sleep stage dataset, ISRUC_S3, illustrate the effectiveness of STFL on graph prediction tasks.


Contributions

  1. We first implement a graph generator for processing spatio-temporal data, including feature extraction and node correlation exploration;
  2. We integrate the graph generator into the proposed STFL, an end-to-end federated learning framework for spatio-temporal GNNs on graph-level classification tasks;
  3. We conduct extensive experiments on the real-world sleep dataset ISRUC S3;
  4. We publish the source code of STFL on GitHub.

Methodology

STFL framework:

Graph Generation

Treat the spatio-temporal series as the raw input. A multivariate sequence S = {s_1, ..., s_T} is defined as a set of time series with T timestamps in total, where each s_i ∈ R^D is a D-dimensional signal. Since there is no notion of a node in spatio-temporal data, we leverage the spatial channels and treat them as nodes, which means that if there are N channels, there will be N nodes in the transformed graph data structure.

Assuming each channel has a time-series set S, the spatio-temporal series over all N channels is denoted as {S^1, ..., S^N}.
Afterwards, the original spatio-temporal data is converted into a feature matrix representation by a CNN-based feature extraction network. The output of the feature extraction network is X = {X_1, ..., X_T}, where d represents the feature dimension, and a snapshot of X is represented as X_T ∈ R^(N×d).
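
To make this step concrete, here is a minimal, illustrative sketch (not the feature extraction network from Jia et al., 2021) of a 1-D CNN that maps each channel's raw signal to a d-dimensional node feature; the layer sizes and signal length are assumptions.

```python
import torch
import torch.nn as nn

class ChannelFeatureExtractor(nn.Module):
    """Toy 1-D CNN mapping each channel's raw signal to a d-dimensional node feature.
    Layer sizes are illustrative, not the network used in the paper."""
    def __init__(self, d: int = 256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=8, stride=2), nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=8, stride=2), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),               # -> (N, 64, 1)
        )
        self.proj = nn.Linear(64, d)

    def forward(self, signals: torch.Tensor) -> torch.Tensor:
        # signals: (N, T_raw) -- one raw time series per channel/node
        h = self.conv(signals.unsqueeze(1)).squeeze(-1)   # (N, 64)
        return self.proj(h)                               # (N, d) == X_T

# 10 channels (nodes), 3000 raw samples per channel (illustrative length)
x_t = ChannelFeatureExtractor()(torch.randn(10, 3000))
print(x_t.shape)  # torch.Size([10, 256])
```
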
After obtaining the refined feature matrix X_T, the correlation between the channels (nodes) needs to be revealed. At this point it is natural to treat X_T ∈ R^(N×d) as the node feature matrix and retrieve the potential correlations between nodes. We therefore define the node correlation function, which takes a node feature matrix as input and outputs an adjacency matrix A_T ∈ R^(N×N):

A_T = Corr(X_T)

where Corr(·) computes the correlation or dependency of each pair of channels (nodes) on the basis of X_T. There are several options for the node correlation function, such as the Pearson correlation function or the phase-locking value (PLV) function.
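
The node correlation step can be sketched as follows, using the Pearson correlation as Corr(·); the thresholding that sparsifies weak correlations is an illustrative assumption rather than part of the paper.

```python
import numpy as np

def pearson_adjacency(x_t: np.ndarray, threshold: float = 0.3) -> np.ndarray:
    """Build A_T = Corr(X_T) with the Pearson correlation as Corr(.).

    x_t: node feature matrix of shape (N, d), one row per channel/node.
    Returns an (N, N) adjacency matrix. The threshold that removes weak
    correlations is an illustrative choice, not taken from the paper.
    """
    corr = np.corrcoef(x_t)             # rows are variables -> (N, N)
    adj = np.abs(corr)
    adj[adj < threshold] = 0.0          # drop weak correlations (assumption)
    np.fill_diagonal(adj, 1.0)          # keep self-loops
    return adj

# Example: 10 channels (nodes), 256-dimensional features per node
X_T = np.random.randn(10, 256)
A_T = pearson_adjacency(X_T)
print(A_T.shape)  # (10, 10)
```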

Graph Neural Network

Along the time dimension, we obtain {G_1, ..., G_T} as the whole graph dataset, denoting the graph generated at each timestamp, and we use {y_1, ..., y_T} for the corresponding graph labels. We then formulate the graph prediction task, in which the label of each graph produced by the graph generator is expected to be correctly predicted. For simplicity of notation, we use V_T to represent the node set of each G_T; the number of nodes equals the number of rows in the node feature matrix X_T. For each v ∈ V_T, the corresponding node feature is written as x_v ∈ R^d.

We use ne[v] to denote the neighborhood of node v, whose associated values can be retrieved from the adjacency matrix A. Then, we formulate the message-passing and readout stages of the GNN. Let h_v^l denote the embedding of node v in layer l. The message passing of node v from layer l to layer l+1 can be formalized as:

h_v^(l+1) = σ( W^(l+1) · AGG({h_u^l : u ∈ ne[v] ∪ {v}}) )

where W^(l+1) is the learnable transformation matrix of layer l+1 and σ is the activation function. The GNN updates the embedding h_v^(l+1) layer by layer.
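
As a concrete (unofficial) illustration of this update rule, the sketch below implements a single message-passing layer with mean aggregation over ne[v] ∪ {v}; the aggregation choice is an assumption, since STFL is evaluated with GCN, GraphSage, and GAT, which each aggregate differently.

```python
import torch
import torch.nn as nn

class MessagePassingLayer(nn.Module):
    """One layer of h_v^{l+1} = sigma(W^{l+1} * AGG({h_u^l : u in ne[v] U {v}})),
    using mean aggregation over neighbors plus the node itself (illustrative)."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.w = nn.Linear(in_dim, out_dim)   # learnable W^{l+1}
        self.act = nn.ReLU()                  # sigma

    def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # h: (N, in_dim) node embeddings, adj: (N, N) adjacency with self-loops
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1.0)
        agg = adj @ h / deg                   # mean over ne[v] U {v}
        return self.act(self.w(agg))          # (N, out_dim)
```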

To obtain a representation of the entire graph after L message-passing layers, the GNN performs a readout operation that derives the final graph representation from the node embeddings:

h_G = Readout({h_v^L : v ∈ V})

Readout(·) is a permutation-invariant operation; it can be as simple as a mean function, or a more complex graph-level pooling function such as an MLP-based pooling.
In the fully supervised setting, we use a shallow neural network to learn a mapping between graph embeddings and the label space Y, where σ(·) is a non-linear transformation. The prediction can be generalized as:

ŷ = σ(MLP(h_G))
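
A minimal sketch of the readout plus the shallow prediction head: mean readout, a single linear layer, and a sigmoid output over the five sleep-stage classes are assumed choices, not necessarily the exact ones in the paper.

```python
import torch
import torch.nn as nn

class GraphHead(nn.Module):
    """Mean readout followed by a shallow classifier: y_hat = sigma(MLP(h_G))."""
    def __init__(self, emb_dim: int, num_classes: int = 5):
        super().__init__()
        self.mlp = nn.Linear(emb_dim, num_classes)

    def forward(self, node_embeddings: torch.Tensor) -> torch.Tensor:
        # node_embeddings: (N, emb_dim) from the last message-passing layer
        h_g = node_embeddings.mean(dim=0)        # permutation-invariant readout
        return torch.sigmoid(self.mlp(h_g))      # per-class scores in (0, 1)
```
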
Furthermore, we utilize a graph-wise binary cross-entropy function to compute the loss L in the supervised setting. The loss function is:

L = - Σ_T [ y_T · log(ŷ_T) + (1 - y_T) · log(1 - ŷ_T) ]
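
Assuming the five sleep stages are encoded as one-hot targets, the graph-wise binary cross-entropy could be computed as in this short sketch.

```python
import torch
import torch.nn.functional as F

# y_hat: (num_graphs, 5) sigmoid outputs, y: (num_graphs, 5) one-hot labels
y_hat = torch.rand(8, 5)
y = F.one_hot(torch.randint(0, 5, (8,)), num_classes=5).float()
loss = F.binary_cross_entropy(y_hat, y)   # mean over graphs and classes
print(loss.item())
```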

Federated Learning

STFL trains GNNs from different clients under a federated learning setting. STFL consists of a central server S and n clients C. Each client deploys a GNN that learns from its local graph data and uploads the GNN's weights to the central server. The central server receives the weights from all clients, updates the weights W_S of the global GNN model, and distributes the updated weights back to each client. In this work, FedAvg is chosen as the aggregation function, which averages the weights of the clients to produce the weights of the global GNN on the server.
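
A minimal sketch of the FedAvg aggregation step: the server averages each client's uploaded state dict element-wise to obtain W_S (unweighted averaging is assumed here; weighting by client data size is another common variant).

```python
from typing import Dict, List
import torch

def fedavg(client_weights: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Average the state_dicts uploaded by the clients to obtain the global W_S."""
    global_weights = {}
    for name in client_weights[0]:
        global_weights[name] = torch.stack(
            [w[name].float() for w in client_weights]
        ).mean(dim=0)
    return global_weights

# Each client would then call: client_model.load_state_dict(global_weights)
```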

Experiment

Dataset

In our experiments, ISRUC S3 (Khalighi et al., 2016) is used as the benchmark dataset. ISRUC S3 collects polysomnography (PSG) recordings from 10 channels of 10 healthy subjects (i.e., sleep-experiment participants). These PSG recordings are labeled with five distinct sleep stages (wake, N1, N2, N3, and REM) according to the AASM standard (Jia et al., 2020). As described in the previous section, we employ a CNN-based feature extraction network (Jia et al., 2021) to generate the initial node features. To generate the adjacency matrix, four different node correlation functions are implemented and discussed separately. To evaluate the effectiveness of STFL, we follow the non-IID data setting (Zhang et al., 2020) and assign different sleep stages to different clients to verify the effectiveness of the proposed framework.
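
To illustrate the non-IID assignment of sleep stages to clients, the sketch below skews each client's label distribution toward one stage; this is one plausible reading of the setting, not the exact partition protocol of Zhang et al. (2020).

```python
import random
from collections import defaultdict

def non_iid_split(labels, num_clients=5, dominant_fraction=0.7, seed=0):
    """Assign sample indices to clients so that client i mostly holds samples
    of sleep stage i (wake, N1, N2, N3, REM). Illustrative only."""
    rng = random.Random(seed)
    by_stage = defaultdict(list)
    for idx, y in enumerate(labels):
        by_stage[y].append(idx)
    clients = [[] for _ in range(num_clients)]
    for stage, idxs in by_stage.items():
        rng.shuffle(idxs)
        cut = int(len(idxs) * dominant_fraction)
        clients[stage % num_clients].extend(idxs[:cut])      # dominant stage
        for j, idx in enumerate(idxs[cut:]):                  # spread the rest
            clients[j % num_clients].append(idx)
    return clients

clients = non_iid_split([random.randrange(5) for _ in range(1000)])
print([len(c) for c in clients])
```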

Node Correlation Functions

  • DB is the Euclidean distance function, used to measure the spatial distance between pairs of electrodes.
  • K-NN (Jiang et al. 2013) generates an adjacency matrix that keeps only the k nearest neighbors of each node to represent the node dependencies of a graph.
  • PCC (Pearson and Lee 1903) is the Pearson correlation coefficient, used to measure the similarity between each pair of nodes.
  • PLV (Aydore, Pantazis, and Leahy 2013) is a time-varying node correlation function that measures the phase synchronization between the signals of each pair of nodes (a minimal PLV sketch follows this list).
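
As referenced above, here is a minimal sketch of how a PLV-based adjacency might be computed via the Hilbert transform (scipy.signal.hilbert); this is the standard textbook formulation, not necessarily the implementation used in the paper.

```python
import numpy as np
from scipy.signal import hilbert

def plv(x: np.ndarray, y: np.ndarray) -> float:
    """Phase-locking value between two equal-length 1-D signals, in [0, 1]."""
    phase_x = np.angle(hilbert(x))        # instantaneous phase of channel x
    phase_y = np.angle(hilbert(y))        # instantaneous phase of channel y
    return float(np.abs(np.mean(np.exp(1j * (phase_x - phase_y)))))

def plv_adjacency(signals: np.ndarray) -> np.ndarray:
    """signals: (N, T_raw) channel signals -> (N, N) PLV adjacency matrix."""
    n = signals.shape[0]
    adj = np.eye(n)
    for i in range(n):
        for j in range(i + 1, n):
            adj[i, j] = adj[j, i] = plv(signals[i], signals[j])
    return adj
```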

Performance Comparative Analysis

  1. To evaluate the effectiveness of the four node correlation functions, we compare the effect of each one on GCN under the federated setting, since GCN has the simplest structure among the three GNN models. As shown in Figure 2, PCC and PLV work well in the federated setting, with faster convergence, especially in the first two epochs. In addition, as shown in Table 2, PLV yields the highest F1 scores across the three federated models, followed by PCC, while DB performs worst. This may be due to the pooling layer in the CNN model (the feature extraction network), which looks at a small temporal window of the input sequence, from which PLV can extract the correct correlation for each pair of nodes.

  2. To evaluate the effectiveness of STFL, we test its performance from different perspectives. In our experiments, we first evaluate the federated graph models on ISRUC S3 with PLV, since PLV performs best among the four node correlation functions discussed in RQ1. As shown in Table 3, all three GNN models produce reasonable results under STFL. In the federated setting, GAT achieves the highest F1 score and accuracy on PLV, with GraphSage second.
    Furthermore, we examine the results of the centralized models for these three graph networks; the results are also shown in Table 3. Here, the hyperparameters are kept the same as in the federated experiments. For data splitting, the test data is the same as in the federated learning experiments, while the training data is randomly sampled from the aggregated data of all clients and its size matches that of a single client. Among the GNNs in the centralized setting, GraphSage achieves the highest F1 score and accuracy, followed by GCN. Furthermore, all models trained under the federated setting achieve better results (F1 score and accuracy) than their centralized counterparts. This indicates that models trained under STFL generalize well over the non-IID data distribution. Another finding is that the best GNN model in the centralized setting is not necessarily the best in the federated setting.

  3. To determine the best match of GNN to STFL, the three GNNs are tested on ISRUC S3 with PLV under the federated framework, since PLV was observed to achieve the best results among all node correlation functions, as analyzed in RQ1. As shown in Figure 3, GCN converges the fastest but is less stable than the other two. We also find that GraphSage converges the slowest in the first epoch but achieves a steady loss reduction afterwards. All three models eventually converge to roughly the same loss, fluctuating around 0.15. Furthermore, we evaluate the per-class F1 score using PLV. Table 4 shows that GraphSage performs best for REM, while GCN scores highest in the other four categories.
    Interestingly, the training loss of all three models fluctuates in a wide range, especially in the last three epochs. This is probably because the federated framework distributes the global model to each client in every training round. In the later stages of training, each client cannot fit its own data well with the generalized global model, especially for models that are prone to overfitting.
