DPGN: Distribution Propagation Graph Network for Few-shot Learning

论文:《DPGN: Distribution Propagation Graph Network for Few-shot Learning》,CVPR2020

代码:https://github.com/megvii-research/DPGN

一、概述

在给定少量标注数据(support集)的情况下,Few-shot learning旨在对未标注数据(query 集)进行预测。

有很多方法可以用于Few-shot learning任务,比如:

  • 微调(Fine-tuning)方法,但容易过拟合
  • 元学习(Meta-Learning)方法,但通常隐式利用样本全局关系
  • 图网络(Graph Networks)方法,但只考虑了样本对关系,忽略了重要的分布关系

如上图所示,该论文提出了DPGN(Distribution Propagation Graph Network)模型,通过未标注数据和已标注数据之间的相似度分布,引导标签信息在图中更好地传播。该模型包含点图(Point Graph, PG)和分布图(Distribution Graph, DP)两个完全图,分别用于建模每个样本的实例级别表示和分布级别表示。具体的含义可以看方法部分。

总的来说,论文的创新点有三点:

  1. DPGN是第一个显式利用分布进行标签传播的图网络Few-shot learning方法。
  2. 提出了双完全图架构,结合了实例级别和分布级别的关系。
  3. 在四个Few-shot learning数据集上进行了实验,在分类任务上提升了5%12%的性能,并在半监督任务中提升了7%13%的性能。

二、方法

首先介绍Few-shot learning的问题定义,然后详细介绍DPGN模型的细节。

1 问题定义

每个Few-shot learning任务都有一个support集\(\mathcal{S}\)和一个query集\(\mathcal{Q}\),二者都属于训练集\(\mathbb{D}^{train}\)\(\mathcal{S}=\{(x_1,y_1),\dots,(x_{N\times K},y_{N\times K})\}\)含有\(N\)个类别,每个类别有\(K\)个样本(也就是\(N\)-way \(K\)-shot),\(\mathcal{Q}=\{(x_{N\times K + 1},y_{N\times K + 1}) \dots, (x_{N\times K + \bar{T}},y_{N\times K + \bar{T}})\}\)含有\(\bar{T}\)个样本。在训练阶段,support集和query集的标签都是已知的。在测试阶段,模型需要根据测试集中的support集预测测试集中query集的标签。

2 DPGN

上图展示了DPGN模型的主要过程,该模型包含\(l\)层,每层包含一个点图(PG) \(G_l^p=(V_l^p, E_l^p)\)和一个分布图(DG) \(G_l^d=(V_l^d, E_l^d)\)。每一层的表示计算顺序基本构成一个环,即\(E_l^p \rightarrow V_l^d \rightarrow E_l^d \rightarrow V_l^p \rightarrow E_{l+1}^p\)

为了进一步说明,节点集合\(V_l^p, V_l^d\)分别表示为\(V_l^p=\{v_{l,i}^p\}\)\(V_l^d=\{v_{l,i}^d\}\),边集合\(E_l^p, E_l^d\)分别表示为\(E_l^p=\{e_{l,ij}^p\}\)\(E_l^d=\{e_{l,ij}^d\}\),其中\(i,j=1,\cdots,T\)\(T=N\times K + \bar{T}\)

\(v_{0,i}^p\)被初始化为特征提取器的输出:

\[v_{0,i}^p = f_{emb}(x_i) \in \mathbb{R}^m \]

2.1 点到分布聚合

2.1.1 点相似度

PG中的每条边都表示实例(点)之间的相似度,也就是样本之间的相似度。

\(l=0\)时,PG的边定义为:

\[e_{0,ij}^p=f_{e_0^p}((v_{0,i}^p - v_{0,j}^p)^2) \in \mathbb{R} \]

其中\(f_{e_0^p}:\mathbb{R}^m \rightarrow \mathbb{R}\)用于将向量映射为标量,论文使用两个Conv-BN-ReLU块实现。

\(l \gt 0\)时,PG的边更新规则如下:

\[e_{l,ij}^p=f_{e_l^p}((v_{l-1,i}^p - v_{l-1,j}^p)^2) \cdot e_{l-1,ij}^p \in \mathbb{R} \]

在实际应用中还要对\(e_{l,ij}^p\)进行归一化。

2.1.2 P2D聚合

生成了PG中的边后,下一步就是生成DG中的节点表示。方法如上图所示,DG中每个节点都是维度为\(N\times K\)的特征向量,其中第\(j\)维表示该实例\(x_i\)与实例\(x_j\)的关系,\(N\times K\)就是support集大小。

\(l=0\)时,DP的节点定义为:

\[v_{0,i}^d = \begin{cases} ||_{j=1}^{NK} \delta(y_i, y_j) \quad \text{if} \ x_i \ \text{is labeled} \\ [\frac{1}{NK},\cdots, \frac{1}{NK}] \quad \text{otherwise} \end{cases} \in \mathbb{R}^{NK} \]

其中\(||\)表示连接操作,\(\delta\)输出0或1表示标签\(y_i\)\(y_j\)是否相等。

\(l \gt 0\)时,DG的节点更新规则如下:

\[v_{0,i}^d = P2D(||_{j=1}^{NK} e_{l,ij}^p, v_{l-1,i}^p) \]

其中,\(P2D: (\mathbb{R}^{NK}, \mathbb{R}^{NK}) \rightarrow \mathbb{R}^{NK}\)是聚合网络,论文使用全连接层加ReLU层实现。

2.2 分布到点聚合

2.2.1 分布相似度

DG中每条边表示实例分布特征的相似度,也就是样本在分布空间的相似度。

\(l=0\)时,DG的边定义为:

\[e_{0,ij}^d = f_{e_0^d}((v_{0,i}^d - v_{0,j}^d)^2) \in \mathbb{R} \]

其中,\(f_{e_0^d}: \mathbb{R}^{NK} \rightarrow \mathbb{R}\)用于将向量映射为标量,论文使用两个Conv-BN-ReLU块实现。

\(l \gt 0\)时,DG中边更新规则如下:

\[e_{l,ij}^d = f_{e_l^d}((v_{l,i}^d - v_{l,j}^d)^2) \cdot e_{l-1,ij}^d \in \mathbb{R} \]

同样需要对\(e_{l,ij}^d\)进行正则化。

2.2.2 D2P聚合

接下来就是利用DG中的边特征,也就是样本的分布相似度,生成PG中的节点特征:

\[v_{l,i}^p = D2P(\sum_{j=1}^T(e_{l,ij}^p \cdot v_{l-1,j}^p), v_{l-1,i}^p) \in \mathbb{R}^m \]

其中,\(D2P: (\mathbb{R}^m, \mathbb{R}^m) \rightarrow \mathbb{R}^m\),论文使用两个Conv-BN-ReLU块实现。

3 训练

为了进行节点分类,只需要将最后一层的边特征输入softmax函数即可:

\[P(\hat{y_i}|x_i) = \text{Softmax}(\sum_{j=1}^{NK}e_{l,ij}^p \cdot one\_hot(y_j)) \]

其中,\(P(\hat{y_i}|x_i)\)就是样本\(x_i\)的预测概率分布,\(y_j\)是support集中第\(j\)个样本的标签,\(e_{l,ij}^p\)表示DPGN最后一层PG中的边特征。

3.1 点损失

点损失就是对节点进行分类的交叉熵损失:

\[\mathcal{L}_l^p = \mathcal{L}_{CE}(P(\hat{y_i}|x_i),y_i) \]

其中,\(\mathcal{L}_{CE}\)是交叉熵函数,\(y_i\)\(x_i\)的标签。

3.2 分布损失

分布损失实际上是在DG层面做节点分类:

\[\mathcal{L}_l^d = \mathcal{L}_{CE}(\text{Softmax}(\sum_{j=1}^{NK}e_{l,ij}^d \cdot one\_hot(y_j)),y_i) \]

模型最终的损失函数由每一层的两部分损失得到:

\[\mathcal{L} = \sum_{l=1}^{\hat{l}}(\lambda_p \mathcal{L}_l^p + \lambda_d \mathcal{L}_l^d) \]

其中\(\hat{l}\)表示DPGN总的层数,\(\lambda_p,\lambda_d\)是权重参数。

三、实验

论文使用了四个Few-shot learning数据集

下面展示一个数据集的实验结果,其他数据集结果可以参照原论文

猜你喜欢

转载自www.cnblogs.com/weilonghu/p/12638808.html