ICLR2022《CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation》

在这里插入图片描述
论文链接:https://arxiv.org/pdf/2109.06165.pdf
代码链接:https://github.com/CDTrans/CDTrans

1. 动机

无监督域适应(Unsupervised domain adaptive, UDA)是一种将知识从有标记的源域转移到不同的无标记目标域的方法。现有的UDA方法大多集中于学习领域不变的特征表示,无论是从领域级别还是类别级别,使用基于卷积神经网络(CNNs)的框架。基于类别级别的UDA的一个基本问题是目标域中的样本会产生伪标签,这些伪标签噪声太大,不利于准确的域对齐,不可避免地影响UDA的性能。随着Transformer在各种任务中的成功应用,发现Transformer中的交叉注意对噪声输入对具有较好的鲁棒性,从而实现更好的特征对齐,因此本文采用Transformer来完成具有挑战性的UDA任务。

2. 方法

提出一个triple-branch transformer框架(CDTrans),利用其对噪声标记数据的鲁棒性和强大的特征对齐能力;为了生成高质量的伪标签,提出了一种双向中心感知的标记方法,提高了CDTrans环境下的最终性能。

  • The Cross Attention in Transformer
    自注意模块的目的是强调输入图像的小块之间的关系:
    在这里插入图片描述
    交叉注意模块来源于自注意模块。不同的是,交叉注意的输入是一对图像,即 I s I_s Is I t I_t It。它的query和key/value分别来自 I s I_s Is I t I_t It的patch。交叉注意模块的计算如下:
    在这里插入图片描述
    在这里插入图片描述
    如图1a所示,交叉注意模块会给假阳性对中的不同patch分配一个低权重,从而在一定程度上减弱了不同patch对最终性能的负面影响;如图1b所示,x轴表示训练数据中的假阳性对率,y轴表示不同方法在UDA任务中的表现,红色曲线表示通过交叉注意模块对对齐得到的结果,绿色曲线表示没有交叉注意的结果,即直接用对应源数据的标签对目标数据进行训练,蓝色曲线是为了从训练数据中去除假阳性对,只使用真阳性对训练交叉注意,在没有噪声数据的情况下,蓝色曲线可以被认为是我们方法的上限。可以看出,红色曲线比绿色曲线的性能好得多,说明交叉注意模块对噪声具有鲁棒性。
  • Two-Way Center-Aware Pseudo Labeling
    1)two-way
    为了构建交叉注意模块的训练对,一种直观的方法是,对源域中的每一幅图像,从目标域中设法找到最相似的图像。所选pair的集合 P S P_S PS为:
    在这里插入图片描述
    其中S, T分别是源数据和目标数据。 d ( f s , f k ) d(f_s, f_k) d(fs,fk)表示图像 i i i与图像 j j j之间的特征距离。该策略的优点是充分利用源数据,缺点是只涉及目标数据的一部分。为了从目标数据中消除这种训练偏差,我们从相反的方向引入更多对 P T P_T PT,由源域中所有的目标数据和它们对应的最相似的图像组成:
    在这里插入图片描述
    因此最终的集合 P P P是两个集合的并集,即 P = P S ∪ P T P = {P_S \cup P_T} P=PSPT,使得训练对包含了所有的源数据和目标数据。
    2)Center-Aware Filtering
    P P P中的对是基于两个域图像的特征相似度构建的,因此对的伪标签的准确性高度依赖于特征相似度。在论文《Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation》的启发下发现源数据的预训练模型也有助于进一步提高准确性。首先,作者将所有的目标数据通过预训练的模型发送出去,从分类器得到它们在源类别上的概率分布δ;与论文相似,这些分布可以通过加权k-means聚类来计算目标域中各个类别的初始中心:
    在这里插入图片描述
    其中, σ t k \sigma^k_t σtk表示图像 t t t出现在类别 k k k上的概率。目标数据的伪标签可以通过最近邻分类器产生:
    在这里插入图片描述
    其中 t ∈ T t \in T tT d ( i , j ) d(i, j) d(i,j)是特征 i i i j j j的距离。基于伪标签,可以计算出新的中心:
    在这里插入图片描述
    对于每一对,如果目标图像的伪标签与源图像的标签一致,则保留这一对进行训练,否则将其作为噪声丢弃。
  • CDTrans: Cross-domain Transformer
    在这里插入图片描述
    提出的跨域变压器(CDTrans)的框架如图2所示,它由三个权重共享变压器组成。权重共享分支有三个数据流和约束。
    框架的输入是我们上面提到的标记方法中选择的对。这三个分支分别命名为源分支、目标分支、源-目标分支。如图2所示,将输入对中的源图像和目标图像分别发送到源支路和目标支路。在这两个分支中,自注意模块用于学习特定于域的表示。并利用softmax交叉熵损失训练分类。值得注意的是,由于两个图像有相同的标签,这三个分支共享同一个分类器。交叉注意模块在源目标分支中导入。源-目标分支的输入来自其他两个分支。 在 第 n 层 中 , 交 叉 注 意 模 块 的 q u e r y 来 自 源 分 支 的 第 n 层 q u e r y , 而 k e y 和 v a l u e 来 自 目 标 分 支 的 k e y 和 v a l u e ( 为 何 不 用 增 强 后 的 目 标 域 和 源 域 特 征 作 为 q u e r y 和 k e y , v a l u e ? ? 这 里 让 人 有 点 疑 惑 \textcolor{red}{在第n层中,交叉注意模块的query来自源分支的第n层query,\\ 而key和value来自目标分支的key和value(为何不用增强后的目\\标域和源域特征作为query和key,value??\\这里让人有点疑惑} nquerynquerykeyvaluekeyvaluequerykeyvalue??。然后交叉注意模块输出与第N-1层的输出对齐的特征。
    由于交叉注意模块的存在,源-目标分支的特征不仅使两个域的分布保持一致,而且对输入对中的噪声具有鲁棒性。因此,我们使用源-目标分支的输出来指导目标分支的训练。其中源-目标分支表示为teacher,目标分支表示为student。我们将源-目标分支中的分类器的概率分布视为一个软标签,可用于通过蒸馏损失进一步监督目标分支:
    在这里插入图片描述
    其中, q k q_k qk p k p_k pk分别是类别 k k k从源-目标分支和目标分支得到的概率。在推理过程中,只使用目标分支。输入是测试数据的图像,只触发目标数据流,如图2中的蓝线。分类器的输出被用作最终的预测标签

3. 部分实验结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

4. 结论

1)本文通过在Transformer中引入交叉注意模块来解决无监督域自适应问题。并提出了一种新的网络结构CDTrans,它是一个纯Transformer结构,有三个分支;
2)提出了一种使用双向中心感知标记方法生成高质量的伪标签。使用生成的高质量伪标签训练CDTrans可以产生一个健壮的解决方案,并在四个流行的UDA数据集上实现最先进的结果,大大超过以前的方法。

猜你喜欢

转载自blog.csdn.net/weixin_43994864/article/details/123324038