【论文下饭】GATE: Graph CCA for Temporal Self-Supervised Learning for Label-Efficient fMRI Analysis

GATE: Graph CCA for Temporal Self-Supervised Learning for Label-Efficient fMRI Analysis

GATE: Graph CCA for Temporal Self-Supervised Learning for Label-Efficient fMRI Analysis
时间:2022
引用:4
期刊会议:IEEE Transactions on Medical Imaging

请添加图片描述

III. METHOD

A.Multi-view fMRI dynamic functional connectivity generation

Dynamic functional connectivity

  • 通过BOLD信号得到FC。
  • FC的上三角展平得到 x i x_i xi,代表了第 i i i个受试者的特征。
  • 构图 G = { X , A } G = \{X, A\} G={ X,A},其中 X X X代表了受试者的特征, A A A代表了受试者之间的相关性。通过Parisot etal. [5]使用KNN获得 A A A

[5] S. Parisot, S. I. Ktena, E. Ferrante, M. Lee, R. Guerrero, B. Glocker, and D. Rueckert, “Disease prediction using graph convolutional networks: application to autism spectrum disorder and alzheimer’s disease,” Medical Image Analysis, vol. 48, pp. 117–130, 2018

总结:节点是受试者,节点特征是FC上三角展平,边是通过KNN计算得到的。

Step window augmentation (S-A)
S-A认为相邻两个窗口的FC是相关视图(related views)。固定窗口长度 L L L,可以得到 M = ⌊ T − L s ⌋ + 1 M=\lfloor {T-L \over s} \rfloor + 1 M=sTL+1个子窗口。
在训练时,S-A随机选择一个子窗口(设为第 m m m个)作为一个视图 G a = { X m , A m } G^a = \{X^m, A^m\} Ga={ Xm,Am},使用其相邻的子窗口作为第二个视图 G b = { X m + 1 , A m + 1 } G^b = \{X^{m+1}, A^{m+1}\} Gb={ Xm+1,Am+1}

Multi-scale window augmentation (M-A)
不同长度窗口的FC包含了相关的信息(relevant information)。
M-A 考虑两个不同长度窗口的FC作为相关视图。

B. Graph embedding

使用两个共享参数的GCN。
得到两个embedding matrices。 Z a = f ( X a , A a ) Z^a = f(X^a, A^a) Za=f(Xa,Aa) Z b = f ( X b , A b ) Z^b = f(X^b, A^b) Zb=f(Xb,Ab)

C. Objective function

对比学习损失[41][42]使用正样本对 和 负样本来优化模型。但是这并不适合有限数量样本和少数量的类(limited number of samples and small number of classes)。

基于重构损失的SSL方法[16],拟合 低信噪比 fMRI特征可能会导致 过拟合 虚假特征。

GATE提出避免采样负样本 或者 是重构fMRI时间序列。

L = − 1 N ∑ i = 1 N ⟨ z i a , z i b ⟩ ∥ z i a ∥ ∥ z i b ∥ + γ ∑ v = a , b ∥ ( Z v ) T Z v − I ∥ F 2 \mathcal{L} = - {1 \over N} \sum_{i=1}^N { \langle z_i^a, z_i^b \rangle \over \lVert z_i^a \rVert \lVert z_i^b \rVert } + \gamma \sum_{v=a, b}\lVert (Z^v)^{\mathrm{T}} Z^v - I \rVert_F^2 L=N1i=1Nziazibzia,zib+γv=a,b∥(Zv)TZvIF2

第一项是regularization of the embedding features,确保低维特征可以保持表征能力。
第二项是确保每个维度特征相互无关(uncorrelated),来避免模式坍塌,比如说模型的所有输出都相同。

总结:相当于每个受试者不同视图的特征 z i z_i zi最大化余弦相似度,同时确保不同受试者的特征 z i z_i zi相互无关。

[16] He K, Chen X, Xie S, et al. Masked autoencoders are scalable vision learners[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022: 16000-16009.
[41] Veličković P, Fedus W, Hamilton W L, et al. Deep graph infomax[J]. arXiv preprint arXiv:1809.10341, 2018.
[42] Hassani K, Khasahmadi A H. Contrastive multi-view representation learning on graphs[C]//International conference on machine learning. PMLR, 2020: 4116-4126.

Performing downstream task
GCN encoder经过训练之后,就没有必要保持大规模图结构 进行fine-tune 阶段。因为许多GNN encoder属于transductive learning。这意味着随着更多样本的加入 A A A需要重新构建。
因此,微调encoder时没有图信息(使用单位矩阵 I I I替换 A A A),接着 linear layer和ELU激活函数。在小规模已经标记的数据上使用 交叉熵损失 计算特定的预测任务。
在推理时,所有的滑动窗口都进入模型。

请添加图片描述


CCA-SSG训练方法:

  1. 在所有节点上进行训练(无监督)。
  2. transductive learning。冻住encoder,仅训练分类头,损失为交叉熵损失。计算训练集节点的损失。

Zhang H, Wu Q, Yan J, et al. From canonical correlation analysis to self-supervised graph neural networks[J]. Advances in Neural Information Processing Systems, 2021, 34: 76-89.


V. EXPERIMENTS

A. Experiment setup

2) Graph construction:
根据[5]得到初始图 A A A。我们从医学图像中提取了低维、可区分的特征。计算得到相似性矩阵 S ∈ R n × n S\in \R^{n\times n} SRn×n,其中 n n n是polulation graph的节点数量,目的是为了限制高维特征的有害影响(adverse influence),比如说噪声,冗余特征以及维度诅咒。

另外,我们采用了phenotype data(比如说,性别、年龄和基因)从另一个角度计算节点相似性。使用了这些信息可以帮助我们生成高质量图。

最后,我们整合edge from imaged-based node features和edge from phenotype information得到初始图 A A A,这是通过两个矩阵哈达玛积得到的。

接着,稀疏化 A A A。仅保留前 k k k大的边,其他的值置为0。最后,我们再加上对角矩阵 I I I I + A → A I + A \rightarrow A I+AA)。

4) Implementation details:
窗口长度 L L L

  • 在S-A中取30、15。
  • 在M-A中随机从[10, 20, 30, 40, 50]中选取。

实验结果展示

性能比较

均值(标准差)。使用20%的标记数据。“gray” 行表示原始的训练结果(自监督 或 有监督)。“Avg” 表示五个评估标准的平均数。

请添加图片描述

不同比例标签数据影响/自监督影响

请添加图片描述

数据增强的影响

DA表示drop augmentation[47]。

[47] Thakoor S, Tallec C, Azar M G, et al. Large-scale representation learning on graphs via bootstrapping[J]. arXiv preprint arXiv:2102.06514, 2021.

请添加图片描述

自监督学习损失函数比较

CL:将损失函数替换为InfoNCE loss[14],并随机挑选负样本。
RE:将损失函数替换为MSE loss并增加decoder。

[14] Oord A, Li Y, Vinyals O. Representation learning with contrastive predictive coding[J]. arXiv preprint arXiv:1807.03748, 2018.

请添加图片描述

embedding维度比较

差不多128就好了。

请添加图片描述

损失函数中超参数 γ \gamma γ影响

差不多0.2就好了。

请添加图片描述

Fine-tune与图结构影响

w/o Fine-tune:无fine-tune步骤。
w/o Graph:无图结构(将 A A A替换为 I I I)。

请添加图片描述


图9没看懂。。


滑动窗口大小 和 步长影响

S-A中滑动窗口大小 和 步长影响:

窗口大小 L L L和步长 s s s分别为[30, 60] 和 [10,35]比较好。

请添加图片描述


M-A中滑动窗口大小 和 步长影响:

两个view的窗口大小在[10, 40] 比较好。

请添加图片描述

猜你喜欢

转载自blog.csdn.net/LittleSeedling/article/details/133845789