GIT: Universidade de Stanford propõe um método de impulso invariável para transformações complexas | ICLR 2022

O artigo estuda a invariância de transformações complexas em conjuntos de dados de cauda longa e descobre que a invariância depende em grande parte do número de imagens na categoria e, de fato, o classificador não pode transferir a invariância aprendida da classe grande para a classe média pequena. Para este fim, o artigo propõe um modelo de geração GIT, que aprende transformações complexas que são independentes de classe do conjunto de dados, de modo a aprimorar efetivamente pequenas turmas durante o treinamento, e o efeito geral é bom

. Fonte: Xiaofei's Algorithm Engineering Notes Official Conta

Pergunta: As redes profundas transferem invariâncias entre classes?

Introdução


  Uma boa generalização requer a capacidade do modelo de ignorar detalhes irrelevantes, como o classificador deve responder se a imagem tem como alvo um gato ou um cachorro, não o plano de fundo ou as condições de iluminação. Em outras palavras, a habilidade de generalização precisa incluir invariância para transformações que são complexas, mas não afetam o resultado previsto. Dadas imagens suficientemente diferentes, como o conjunto de dados de treinamento contendo imagens de gatos e cães em um grande número de contextos diferentes, as redes neurais profundas podem realmente aprender a invariância. Mas se todas as imagens de treinamento da classe do cachorro estiverem no fundo da grama, o classificador provavelmente julgará mal o cachorro no fundo da casa como um gato, o que geralmente é um problema com conjuntos de dados desequilibrados.
Desequilíbrio de classe é comum na prática, e muitos conjuntos de dados do mundo real seguem distribuições de cauda longa com muitas imagens para todas, exceto algumas classes de cabeça, e poucas imagens para cada uma das classes de cauda restantes. Portanto, mesmo que a quantidade total de imagens em um conjunto de dados de cauda longa seja grande, pode ser difícil para o classificador aprender a invariância das classes de cauda. Embora o aumento de dados comumente usado possa resolver esse problema aumentando o número e a diversidade de imagens na classe de cauda, ​​essa estratégia não pode ser usada para imitar transformações complexas, como alterar o plano de fundo das imagens. É importante notar que muitas transformações complexas, como mudanças de iluminação, são independentes de classe e podem ser aplicadas da mesma forma a qualquer classe de imagens. Idealmente, um modelo treinado deve ser capaz de converter automaticamente esses invariantes em invariantes agnósticos de classe compatíveis com as previsões de classe de cauda.
O artigo observa a capacidade do classificador de transferir a invariância aprendida entre classes por meio de experimentos, e constata-se a partir dos resultados que, mesmo após balancear estratégias como oversampling, a rede neural transfere a invariância aprendida entre diferentes classes. Por exemplo, em um conjunto de dados de cauda longa em que cada imagem é girada uniformemente de forma aleatória, o classificador tende a ser invariante de rotação para imagens da classe head, mas não invariante de rotação para imagens da classe tail.
Para este fim, o artigo propõe um método simples para transferir invariância entre classes de forma mais eficiente. Primeiro treinamos um modelo generativo condicionado à entrada, mas independente de classe, que captura transformações complexas do conjunto de dados, ocultando informações de classe para incentivar a transferência de transformação entre as classes. Esse modelo generativo é então usado para transformar a entrada de treinamento, semelhante ao aumento de dados de aprendizado para treinar um classificador. O artigo prova por meio de experimentos que, como a invariância da classe de cauda é significativamente melhorada, o classificador geral é mais invariável a transformações complexas, resultando em melhor precisão do teste.

Como medir a transferência de invariância em conjuntos de dados com desbalanceamento de classe


  O artigo primeiro apresenta a invariância em cenários desequilibrados, depois define um indicador para medir a invariância e, finalmente, analisa a relação entre a invariância e o tamanho da categoria.

Configuração:Classificação,Desequilíbrio,e Invariâncias

  definir entrada ( x , S ) (x,y) , etiqueta S S pertence a { 1 , , C } \{1,\cdots,C\} C C é o número de categorias. Defina os pesos do modelo treinado dentro dentro , para prever probabilidades condicionais P ~ dentro ( S = j x ) \tilde{P}_w(y=j|x) , o classificador selecionará a classe com maior probabilidade j j como saída. determinado conjunto de treinamento { ( x ( eu ) , S ( eu ) ) } eu = 1 N P t r a i n \{(x^{(i)}, y^{(i)})\}^N_{i=1}\sim \mathbb{P}_{train} ,通过经验风险最小化(ERM)来最小化训练样本的平均损失。但在不平衡场景下,由于 { y ( i ) } \{y^{(i)}\} 的分布不是均匀的,导致ERM在少数类别上表现不佳。
在现实场景中,最理想的是模型在所有类别上都表现得不错。为此,论文采用类别平衡的指标来评价分类器,相当于测试分布 P t e s t \mathbb{P}_{test} y y 上是均匀的。
为了分析不变性,论文假设 x x 的复杂变换分布为 T ( x ) T(\cdot|x) 。对于不影响标签的复杂变换,论文希望分类器是不变的,即预测的概率不会改变:

Measuring Learned Invariacnes

  为了度量分类器学习不变性的程度,论文定义了原输入和变换输入之间的期望KL散度(eKLD):

  这是一个非负数,eKLD越低代表不变性程度就越高,对 T T 完全不变的分类器的eKLD为0。如果有办法采样 x T ( x ) x^{'}\sim T(\cdot|x) ,就能计算训练后的分类器的eKLD。此外,为了研究不变性与类图片数量的关系,可以通过分别计算类特定的eKLD进行分析,即将公式2的 x x 限定为类别 j j 所属。
计算eKLD的难点在于复杂变化分布 T T 的获取。对于大多数现实世界的数据集而言,其复杂变化分布是不可知的。为此,论文通过选定复杂分布来生成数据集,如RotMNIST数据集。与数据增强不同,这种生成方式是通过变换对数据集进行扩充,而不是在训练过程对同一图片应用多个随机采样的变换。
论文以Kuzushiji-49作为基础,用三种不同的复杂变换生成了三个不同的数据集:图片旋转(K49-ROT-LT)、不同背景强度(K49-BG-LT)和图像膨胀或侵蚀(K49-DIL-LT)。为了使数据集具有长尾分布(LT),先从大到小随机选择类别,然后有选择地减少类别的图片数直到数量分布符合参数为2.0的Zipf定律,同时强制最少的类为5张图片。重复以上操作30次,构造30个不同的长尾数据集。每个长尾数据集有7864张图片,最多的类有4828张图片,最小的类有5张图片,而测试集则保持原先的不变。

  训练方面,采用标准ERM和CE+DRS两种方法,其中CE+DRS基于交叉熵损失进行延迟的类平衡重采样。DRS在开始阶段跟ERM一样随机采样,随后再切换为类平衡采样进行训练。论文为每个训练集进行两种分类器的训练,随后计算每个分类器每个类别的eKLD指标。结果如图1所示,可以看到两个现象:

  • 在不同变化数据集上,不变性随着类图片数减少都降低了。这表明虽然复杂变换是类无关的,但在不平衡数据集上,模型无法在类之间传递学习到的不变性。
  • 对于图片数量相同的类,使用CE+DRS训练的分类器往往会有较低的eKLD,即更好的不变性。但从曲线上看,DRS仍有较大的提升空间,还没达到类别之间一致的不变性。

Trasnferring Invariances with Generative Models


  从前面的分析可以看到,长尾数据集的尾部类对复杂变换的不变性较差。下面将介绍如何通过生成式不变性变换(GIT)来显式学习数据集中的复杂变换分布 T ( x ) T(\cdot|x) ,进而在类间转移不变性。

Learning Nuisance Transformations from Data

  如果有数据集实际相关的复杂变换的方法,可以直接将其用作数据增强来加强所有类的不变性,但在实践中很少出现这种情况。于是论文提出GIT,通过训练input conditioned的生成模型 T ~ ( x ) \tilde{T}(\cdot|x) 来近似真实的复杂变换分布 T ( x ) T(\cdot|x)

  论文参考了多模态图像转换模型MUNIT来构造生成模型,该类模型能够从数据中学习到多种复杂变换,然后对输入进行变换生成不同的输出。论文对MUNIT进行了少量修改,使其能够学习单数据集图片之间的变换,而不是两个不同域数据集之间的变换。从图2的生成结果来看,生成模型能够很好地捕捉数据集中的复杂变换,即使是尾部类也有不错的效果。需要注意的是,MUNIT是非必须的,也可以尝试其它可能更好的方法。
在训练好生成模型后,使用GIT作为真实复杂变换的代理来为分类器进行数据增强,希望能够提高尾部类对复杂变换的不变性。给定训练输入 { ( x ( i ) , y ( i ) ) } i = 1 B \{(x^{(i)}, y^{(i)})\}^{|B|}_{i=1} ,变换输入 x ~ ( i ) T ~ ( x ( i ) ) \tilde{x}^{(i)}\gets \tilde{T}(\cdot|x^{(i)}) ,保持标签不变。这样的变换能够提高分类器在训练期间的输入多样性,特别是对于尾部类。需要注意的是,batch可以搭配任意的采样方法(Batch Sampler),比如类平衡采样器。此外,还可以有选择地进行增强,避免由于生成模型的缺陷损害性能的可能性,比如对数量足够且不变性已经很好的头部类不进行增强。

  在训练中,论文设置阈值 K K ,仅图片数量少于 K K 的类进行数据增强。此外,仅对每个batch的 p p 比例进行增强。 p p 一般取0.5,而 K K 根据数据集可以设为20-500,整体逻辑如算法1所示。

GIT Improves Invariance on Smaller Classes

  论文基于算法1进行了实验,将Batch Sampler设为延迟重采样(DRS),Update Classifier使用交叉熵梯度更新,整体模型标记为 C E + D R S + G I T ( a l l c l a s s e s ) CE+DRS+GIT(all classes) 。all classes表示禁用阈值 K K ,仅对K49数据集使用。作为对比,Oracle则是用于构造生成数据集的真实变换。从图3的对比结果可以看到,GIT能够有效地增强尾部类的不变性,但同时也损害了图片充裕的头部类的不变性,这表明了阈值 K K 的必要性。

Experiment


  不同训练策略搭配GIT的效果对比。

  在GTSRB和CIFAR数据集上的变换输出。

  CIFAR-10上每个类的准确率。

  对比实验,包括阈值 K K 对性能的影响,GTSRB-LT, CIFAR-10 LT和CIFAR-100 LT分别取25、500和100。这里的最好性能貌似都比RandAugment差点,有可能是因为论文还没对实验进行调参,而是直接复用了RandAugment的实验参数。这里比较好奇的是,如果在训练生成模型的时候加上RandAugment,说不定性能会更好。

Conclusion


  O artigo estuda a invariância de transformações complexas em conjuntos de dados de cauda longa e descobre que a invariância depende em grande parte do número de imagens na categoria e, de fato, o classificador não pode transferir a invariância aprendida da classe grande para a classe média pequena. Para este fim, o artigo propõe um modelo de geração GIT, que aprende transformações complexas independentes de classe a partir do conjunto de dados, de modo a aprimorar efetivamente pequenas classes durante o treinamento, e o efeito geral é bom.



Se este artigo for útil para você, dê um like ou assista
. Para mais informações, preste atenção na conta pública do WeChat [Notas de engenharia de algoritmo de Xiaofei]

おすすめ

転載: juejin.im/post/7121571917164707854