NIPS2019《Cross Attention Network for Few-shot Classification》

在这里插入图片描述
发表于NIPS2019!!!
论文链接:https://proceedings.neurips.cc/paper/2019/file/01894d6f048493d2cacde3c579c315a3-Paper.pdf
代码链接:https://github.com/blue-blue272/fewshot-CAN

1. 动机

在这里插入图片描述
虽然有希望,但很少有人对所提取的特征的可识别性给予足够的重视。它们通常独立地从支持类和无标签查询样本中提取特征,因此特征不够有区别。一方面,支持/查询集中的测试图像来自不可见的类,因此它们的特性几乎不能用于目标物体。具体来说,对于包含多个目标的测试图像,提取的特征可以关注训练集中有大量标记样本的已见类中的目标,而忽略未见类中的目标。如上图1©和(d),来自测试类窗帘的两张图像,提取的特征只捕获与训练类相关的目标的信息,如图1 (a)和(b)中的人或椅子。另一方面,低数据问题使得每个测试类的特征不能代表真正的类分布,因为它是由非常少的label支持样本获得的。总之,独立特征表示在小样本分类中可能会失败。

2. 贡献

在本工作中,提出了一种新的**交叉注意网络(CAN)**来提高小样本分类的特征可鉴别性。
1)首先,引入交叉注意模块(CAM)来解决不可见类问题。交叉注意的想法是受人类少样本学习行为的启发。为了从一个未被发现的类别中识别出一个样本,人类倾向于首先在标记和未标记的样本对中定位最相关的区域。类似地,给定一个类特征图和一个查询样例特征图,CAM为每个特征生成一个交叉注意图来突出显示目标对象。为了达到这一目的,采用了相关估计和元融合方法。这样可以使测试样本中的目标对象获得注意,交叉注意图加权的特征具有更强的判别性。如图1 (e)所示,利用CAM提取的特征可以对目标物体幕区域进行粗略的定位。
2)其次,我们引入了一个直推推理算法,利用整个无label查询集来缓解低数据问题。该算法迭代预测查询样本的标签,并选择伪标签查询样本来扩大支持集。每个类支持样本越多,得到的类特征就越有代表性,从而缓解了低数据问题。

3. 方法

3.1 问题定义

少样本分类通常包括一个训练集、一个支持集和一个查询集。训练集包含大量的类和标注的样本。少数标注样本的支持集和无标注样本的查询集共享同一个标注空间,而标注空间与训练集的标注空间是不相连的。少样本分类的目的是对给定训练集和支持集的无标记查询样本进行分类。如果支持集由C类和每个类的K个标记样本组成,则目标少样本问题称为C-way K-shot。
根据已有经验,本文也采用episode训练机制,该机制已被证明是一种有效的少样本学习方法。训练中使用的episode模拟了测试中的设置。每个episode是由随机抽样 C C C类和每个类 K K K个标记样本作为支持组 S = { ( x a s , y a s ) } a = 1 n s ( n s = C × K ) \mathcal{S} = \{ (x^s_a, y^s_a)\}^{n_s}_{a=1} (n_s = C \times K) S={ (xas,yas)}a=1ns(ns=C×K) C C C类中剩余样本的一小部分作为查询集 Q = { ( x b q , y b q ) } b = 1 n q \mathcal{Q} = \{ (x^q_b, y^q_b)\}^{n_q}_{b=1} Q={ (xbq,ybq)}b=1nq组成。我们将 S k \mathcal{S}^k Sk表示为第 k k k类的支持子集。如何表示每个支持类 S k \mathcal{S}^k Sk和查询样本 x b q x^q_b xbq,并度量它们之间的相似性是少样本分类的关键问题。

3.2 Cross Attention Module

在本工作中,我们通过度量学习为每对支持类和查询样本获得适当的特征表示。本文提出了交叉注意模块(Cross Attention Module, CAM),该模块可以对类特征和查询特征之间的语义相关性进行建模,从而引起对目标物体的注意,有利于后续的匹配。
在这里插入图片描述
CAM如图上图(a)所示。类特征映射 P k ∈ R c × h × w P^k \in \mathbb{R}^{c \times h \times w} PkRc×h×w是从支持样本 S k ( k ∈ { 1 , 2 , ⋯   , C } ) \mathcal{S}^k (k \in \{ 1, 2, \cdots, C\}) Sk(k{ 1,2,,C})中提取,而查询特征映射 Q b ∈ R c × h × w Q^b \in \mathbb{R}^{c \times h \times w} QbRc×h×w是从查询样本 x b q ( b ∈ { 1 , 2 , ⋯   , n q } ) x^q_b (b \in \{ 1, 2, \cdots, n_q\}) xbq(b{ 1,2,,nq})中提取。其中 c c c h h h w w w分别为特征图的通道数、高度、宽度。CAM为 P k ( Q b ) P^k (Q^b) Pk(Qb)生成交叉注意图 A p ( A q ) A^p (A^q) Ap(Aq),然后利用 A p ( A q ) A^p (A^q) Ap(Aq)对特征图进行加权,实现更具有判别性的特征表示 P ˉ b k ( Q ˉ k b ) \bar{P}^k_b (\bar{Q}^b_k) Pˉbk(

猜你喜欢

转载自blog.csdn.net/weixin_43994864/article/details/123349370