KAN: KNOWLEDGE-AUGMENTED NETWORKS FOR FEW-SHOT LEARNING

https://ieeexplore.ieee.org/document/9413612

论文

方法

对支持集和查询集的图片用 f ϕ ( ⋅ ) f_\phi(·) fϕ()编码,每张图片对应的输出为 W × H × C W × H × C W×H×C,论文中设置通道数 C = 512 C=512 C=512.

在训练阶段,通过图卷积网络 g ψ ( ⋅ ) g_\psi(·) gψ()对图G进行编码,将知识图中每个节点表示为512维度的向量。 g ψ ( G ) k g_\psi(\mathcal{G})_k gψ(G)k表示类k对应节点的语义信息.

然后将查询集每一类的原型 c k = 1 ∣ S k ∣ ∑ ( x i , y i ) ∈ S k f ϕ ( x i ) c_k= \frac{1}{|S_k|} \sum_{(x_i,y_i)\in S_k}{f_\phi(x_i)} ck=Sk1(xi,yi)Skfϕ(xi)和知识图中输出的512维度向量,增强类的表达形式: e k = 1 2 ( c k + g ψ ( G ) k ) e_k = \frac{1}{2} (c_k+g_\psi(G)_k) ek=21(ck+gψ(G)k).由于 g ψ ( G ) k g_\psi(G)_k gψ(G)k是512维的向量,而 c k c_k ck W × H × C W × H × C W×H×C的矩阵,所以需要将 g ψ ( G ) k g_\psi(G)_k gψ(G)k重复 H × W H×W H×W次.

然后将编码后的查询向量和每个类的表示做像素级别的相似余弦度,查询特征 x q x_q xq在类上的每个位置生成像素级分布:
p ϕ , ψ ( y = k ∣ x q ) = e x p ( d ( e k , f ϕ ( x q ) ) ) ∑ j e x p ( d ( e j , f ϕ ( x q ) ) ) p_{\phi,\psi}(y=k|x_q)=\frac{exp(d(e_k,f_\phi(x_q)))}{\sum_jexp(d(e_j,f\phi(x_q)))} pϕ,ψ(y=kxq)=jexp(d(ej,fϕ(xq)))exp(d(ek,fϕ(xq)))
相似损失函数为:
L S = ∑ b L S b = ∑ b ( − l o g p ϕ , ψ ( y = k ∣ x q ) ) b \mathcal{L}_S =\sum_b\mathcal{L}_S^b= \sum_b(-logp_{\phi,\psi}(y=k|x_q))^b LS=bLSb=b(logpϕ,ψ(y=kxq))b其中 b b b表示在查询特征中的位置。

为了避免图像特征在融合过程中丢失一些关键信息,引入了分类器,用于对查询图像进行分类,目的是保留视觉语义融合过程中的关键特征。使用一个带有 s o f t m a x softmax softmax 1 × 1 1 × 1 1×1卷积层 W \mathcal{W} W,在所有可用的训练类上的对每个查询样本进行分类,这为所有训练类上的查询特征 x q x_q xq中的每个位置生成一个像素级分布:

p ϕ , W ( y = l ∣ W x q ) = e x p ( ( W x q ) l ) ∑ j e x p ( ( W x q ) j ) p_{\phi,\mathcal{W}}(y=l|\mathcal{W}x_q)=\frac{exp((\mathcal{Wx_q})_l)}{\sum_jexp((\mathcal{Wx_q})_j)} pϕ,W(y=lWxq)=jexp((Wxq)j)exp((Wxq)l)
其中 W \mathcal{W} W是卷积层的参数, l l l表示 x q x_q xq所有训练类中的一个标签。
分类的损失定义为:
L c = ∑ b L c b = ∑ b ( − l o g p ϕ , w ( y = l ∣ W x q ) ) b \mathcal{L}_c = \sum_b\mathcal{L}_c^b=\sum_b(-log p_{\phi,\mathcal{w}}(y=l|\mathcal{W}x_q))^b Lc=bLcb=b(logpϕ,w(y=lWxq))b
其中 b b b 表示分类器输出的特征中的位置。 最后,总损失定义为 L = L s + L c \mathcal{L} = \mathcal{L}_s+\mathcal{L}_c L=Ls+Lc.KAN采用端到端的方式进行训练,并最大限度地减少损失 L \mathcal{L} L

流程

在这里插入图片描述

实验效果

从图 2 可以看出,在引入知识图谱之前,类的表示相当混乱,没有清晰的界限。 在增强知识图谱的语义特征后,我们观察到特征空间中的清晰解开,每个类之间具有明显的边界。
在这里插入图片描述

扫描二维码关注公众号,回复: 13263884 查看本文章

自己的理解

在这里插入图片描述
有几个问题没搞懂,邮件联系了作者,作者回复我了

问题1:原文等式(4)为什么有个负号

我当时理解成了原型网络里的距离函数了,认为

如果有负号,而想要损失函数下降,这要求log变大,就是p变大,也就是等式(3)的分子变大,就是使得提取的查询特征和标签对应原型的特征距离变大

作者在邮件中回答我:

相似度我们使用的是余弦,所以我们需要使得提取的查询特征和标签对应原型的余弦相似度变大。

关于相似余弦度可以看这篇博客 余弦距离,欧式距离,马氏距离之间的关系,也就是两个向量如果越相似,则他们的相似余弦度越接近1
在这里插入图片描述
两个矩阵求cosine np.cos

问题2:原文等式(4)(5)为什么有个指数b

作者一开始回复我

A:因为是像素级的损失函数,b指代的为每个像素点。一张图片有很多个像素点,所以最终的损失值应该为所有像素点叠加得来。

但是我没有弄懂,又追问了作者,作者回复我

因为图片特征提取后,没有采用全连接层(具体原因在introduction里有介绍),所以经过CNN得出来的特征为51266,然后经过余弦相似度计算后得到 6 ∗ 6 6* 6 66的特征,公式中的b即为这6*6个点中的任意点,叠加后得到最终的损失。

问题3:为什么 C l a s s i f i c a t i o n Classification Classification 1 × 1 1\times1 1×1卷积层而不是用预训练好的resNet模型等

本文目的是为了提出一个思路,不是为了与其他方法拼结果,所以没采用训练好的resnet。如果采用预训练的,结果肯定会更好,但是违背了初衷,本文目的只为证明引入知识图谱是一种有效的方法。

源代码

链接:https://pan.baidu.com/s/1Zm6NmVKVH-oOovYpT884mg
提取码:dq6i

创新点

  • 类的语义特征通过图卷积网络编码到知识图谱
  • 将零样本中用到语义信息的方法通过知识图谱的方式放到了小样本中
  • 使用像素级交叉熵,使查询嵌入的位置正确分类

猜你喜欢

转载自blog.csdn.net/qq_37252519/article/details/119146341