Study Notes on Domain Generalization (DG)


Author: S-Lab, NTU + HKUST


1. Overview of Domain Generalization

1) Domain definition

A domain consists of data sampled from a distribution. In the definition below, each S is a domain.
[Figure: formal definition of a domain S]
Some related mathematical notation is as follows:
[Figure: mathematical notation]

2) Definition of Domain Generalization (DG)

As shown in the figure below, the goal of domain generalization is to learn a model from one or several different but related domains (the training sets) that generalizes well to an unseen test domain. In this definition, "different but related" is the key point: the domains differ, but they must be related, and in practice each domain contains the same categories.
[Figure: illustration of domain generalization]
DG is defined as follows: after training on M source domains, the aim is to obtain a generalizable prediction function h whose error on the unseen test domain is as small as possible.
[Figures: formal definition of domain generalization]
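For reference, the definitions referred to above can be written out as follows. This is a reconstruction in the notation of the survey listed in the reference line at the end of this section ("Generalizing to Unseen Domains"), so the symbols n, n_i and the loss ℓ are assumptions and may differ from the original figures.

```latex
% A domain: n samples drawn from a joint distribution P_{XY}
\mathcal{S} = \{(x_i, y_i)\}_{i=1}^{n} \sim P_{XY},
\qquad x_i \in \mathcal{X},\; y_i \in \mathcal{Y}

% M source (training) domains with pairwise different joint distributions
\mathcal{S}_{train} = \{\mathcal{S}^i \mid i = 1, \dots, M\},
\qquad \mathcal{S}^i = \{(x^i_j, y^i_j)\}_{j=1}^{n_i},
\qquad P^i_{XY} \neq P^j_{XY} \;\; (i \neq j)

% Goal: a predictor h with minimum error on the unseen test domain
\min_{h} \; \mathbb{E}_{(x, y) \in \mathcal{S}_{test}} \big[ \ell(h(x), y) \big]
```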
In other words, every domain in DG contains the same categories, but the joint distributions differ: the distribution $P_{XY}$ from which each domain S is sampled is different. The figure below shows this intuitively: each domain contains the same categories but renders them in a different style, such as sketch, cartoon, or art painting, and each style is called a domain.
[Figure: the same categories in different styles (sketch, cartoon, art painting)]

3) The difference between some related fields and DG

[Figure: comparison of DG with related learning settings]

  • Transfer learning: Transfer learning trains a model on a source task with the aim of improving performance on a different but related target domain or task. Pre-training followed by fine-tuning is a common strategy; the source and target tasks differ, and the target domain is visible during training. In DG, the target domain is not accessible during training, and the training and test tasks are usually the same while their distributions differ.
  • Domain adaptation (DA): The goal of DA is to use the available source domain(s) to maximize performance on a given target domain. The difference from DG is that DA has access to target-domain data, while DG cannot see that data during training. DA can be seen as a variant of transfer learning: given a model trained on the source domain, DA uses sparse (a small amount of labeled) or unlabeled data from the target domain to correct or fine-tune it. This fine-tuning step is called adaptation; it lets the model adapt to the new domain and overcome the domain shift (DS) problem.
  • Meta-learning: Meta-learning is a general learning strategy that can improve DG performance by simulating meta-train and meta-test tasks within the training domains.
  • Zero-shot learning: In zero-shot learning the categories differ between training and testing, whereas in DG they are the same. Zero-shot learning is concerned not with differences between domains but with differences between categories: the classes encountered at test time are entirely new and were never seen during training. In DG, the classes are the same during training and testing, but at test time they appear in a domain that was not seen during training. A related setting is few-shot learning, which focuses on classification from a limited number of samples: the model is trained on many categories with few samples each, and at test time a new category can be recognized after fine-tuning on only a few samples.

4) Domain generalization method

DG methods fall into three main categories: data manipulation, representation learning, and learning strategies, as shown below.
[Figure: taxonomy of DG methods]
Data manipulation: improves generalization by augmenting and altering the training data. This category has two parts: data augmentation and data generation. Data augmentation adds samples through augmentation, randomization, and transformation of the input data; data generation produces diverse samples with models such as VAEs or GANs.

Representation learning: the aim is to learn domain-invariant representations so that the model adapts well to different domains. Domain-invariant feature learning mainly comprises four parts: kernel methods, explicit feature alignment, domain adversarial training, and invariant risk minimization (IRM). A separate line of work, feature disentanglement, shares the goal of domain-invariant feature learning but takes a different route: it tries to decompose features into domain-shared (domain-invariant) parts and domain-specific parts for better generalization. (However, domain-invariant representation learning may not be suitable for GeNeRF.)

Learning strategy: brings mature learning paradigms from machine learning into multi-domain training to make the model more generalizable. This part mainly includes methods based on ensemble learning and meta-learning. There are also other approaches, such as self-supervised learning, which learns generalizable representations by constructing pretext tasks, and metric learning.

The rest of this section focuses on the mainstream category, representation learning.

Representation learning

From the definition of DG above, the goal of domain generalization is to learn a generalizable prediction function h. Here h is decomposed as $h = f \circ g$, where g is a representation learning function and f is a classification function. This definition also shows that current domain generalization research mainly targets classification tasks.

In domain generalization, the objective of representation learning can be defined as follows:
[Figure: representation learning objective]
where $l_{reg}$ is a regularization term. Many methods are designed around better learning the feature extraction function g together with the corresponding regularization term $l_{reg}$. According to the underlying learning principle, representation learning is divided into two categories: domain-invariant feature learning and feature disentanglement.
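With the decomposition $h = f \circ g$, the objective is commonly written as below. This follows the form used in the survey listed in the reference line at the end of this section; the trade-off weight λ and the loss ℓ are notation assumed here and may differ from the original figure.

```latex
\min_{f,\, g} \; \mathbb{E}_{x, y} \, \ell\big(f(g(x)), y\big) + \lambda \, l_{reg}
```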

Domain Invariant Representation Learning

Theoretical work (Analysis of Representations for Domain Adaptation) has shown that if a feature representation is invariant across domains, then it is general and transferable to different domains. Many domain adaptation algorithms build on this result. Similarly, in domain generalization the goal is to reduce the representation discrepancy between multiple source domains in a given feature space so that the representation is domain-invariant (roughly, learning the intersection of the domains), which lets the learned model generalize to unseen domains.

From this perspective, domain-invariant representation learning falls into four categories: kernel-based methods, domain adversarial learning, explicit feature alignment, and invariant risk minimization.

① Kernel-based methods

The basic idea of kernel-based machine learning is to map data that is not linearly separable in a low-dimensional space into a high-dimensional space where it becomes linearly separable. This helps solve nonlinear problems in machine learning tasks such as classification, clustering, and regression.

The kernel function is the core of kernel-based machine learning. It typically measures the similarity between samples and implicitly maps data from a low-dimensional space to a high-dimensional one: there is no need to compute the coordinates of the data in the high-dimensional space, only the inner products between all pairs of samples in the feature space. For example, the RBF (radial basis function) kernel is one of the most commonly used kernel functions and is defined as:
$$k(x, x') = \exp\left(-\gamma \lVert x - x' \rVert^2\right)$$
Here x and x' are data samples and γ is a tunable parameter that controls the shape (width) of the kernel. One of the most representative kernel-based methods is the support vector machine (SVM).
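As a small illustration, the RBF kernel in its standard form k(x, x') = exp(-γ‖x − x'‖²) can be computed for every pair of samples at once. The sketch below assumes NumPy; the function name `rbf_kernel` and the toy data are made up for this example.

```python
import numpy as np

def rbf_kernel(X, Y, gamma=1.0):
    """Return the matrix K with K[i, j] = exp(-gamma * ||X[i] - Y[j]||^2)."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    # Squared Euclidean distances via ||x||^2 + ||y||^2 - 2 x.y
    sq_dists = (
        np.sum(X**2, axis=1)[:, None]
        + np.sum(Y**2, axis=1)[None, :]
        - 2.0 * X @ Y.T
    )
    sq_dists = np.maximum(sq_dists, 0.0)  # guard against tiny negative round-off
    return np.exp(-gamma * sq_dists)

# Three 2-D points; K[i, j] is the similarity between point i and point j.
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
K = rbf_kernel(X, X, gamma=0.5)
```

A larger γ makes the kernel narrower, so only very close points count as similar. Note that the coordinates in the high-dimensional feature space are never computed.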

Many DG algorithms are kernel-based: they implement the representation learning function g as a feature map $\phi(\cdot)$ that can be computed easily through a kernel function $k(\cdot,\cdot)$ (such as the RBF kernel or the Laplacian kernel), that is, through similarities between samples. In general, kernel-based methods achieve domain generalization by mapping data into a high-dimensional feature space and computing similarity or distance in that space.

Kernel-based methods are applied to generalization in two main ways (see a review of generalized deep learning):

  • Domain adaptation: In domain adaptation problems, the training and test data come from different distributions (domains); the model learns on the training domain and must then generalize to the test domain. Because part of the test-domain data is visible in the DA setting, kernel-based domain adaptation methods introduce adaptation constraints in the kernel space, such as the maximum mean discrepancy (MMD), to reduce the gap between the training and test distributions and thereby improve performance on the test data.
  • Transfer learning: In cross-domain learning problems, the training and test data may come from different domains but share information or knowledge that can be used for generalization. Kernel-based cross-domain methods account for the similarity between the source and target domains in the kernel function, so that knowledge learned in the source domain helps improve generalization in the target domain.

Next, a look at maximum mean discrepancy (MMD). MMD is a non-parametric measure of the distance between two probability distributions. It can be used in kernel methods to quantify the difference between two sample sets. The basic idea is to map the data into a feature space with a feature map and compare the two distributions there.

MMD works by computing the difference between the mean embeddings of the two distributions. Specifically, for two probability distributions P and Q, MMD measures their distance by comparing their representations in the feature space. A feature map is therefore needed to take samples from the original space to the feature space, where the difference between the two distributions is computed. This can be done implicitly with a kernel function.
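A minimal sketch of this computation, assuming NumPy and an RBF kernel: the squared MMD between two sample sets can be estimated purely from kernel values, as mean k(x, x') + mean k(y, y') − 2 · mean k(x, y). The function names and the toy Gaussian data are illustrative only, and this is the simple biased estimator.

```python
import numpy as np

def rbf_kernel(X, Y, gamma=1.0):
    """Pairwise RBF kernel values exp(-gamma * ||x - y||^2)."""
    sq = np.sum(X**2, 1)[:, None] + np.sum(Y**2, 1)[None, :] - 2.0 * X @ Y.T
    return np.exp(-gamma * np.maximum(sq, 0.0))

def mmd2(X, Y, gamma=1.0):
    """Biased estimate of the squared MMD between samples X ~ P and Y ~ Q."""
    k_xx = rbf_kernel(X, X, gamma).mean()  # similarity within X
    k_yy = rbf_kernel(Y, Y, gamma).mean()  # similarity within Y
    k_xy = rbf_kernel(X, Y, gamma).mean()  # similarity across X and Y
    return k_xx + k_yy - 2.0 * k_xy

rng = np.random.default_rng(0)
same_a = rng.normal(0.0, 1.0, size=(200, 2))   # samples from P
same_b = rng.normal(0.0, 1.0, size=(200, 2))   # more samples from P
shifted = rng.normal(3.0, 1.0, size=(200, 2))  # samples from a shifted Q

mmd_same = mmd2(same_a, same_b)      # small: both sets come from P
mmd_shifted = mmd2(same_a, shifted)  # larger: the distributions differ
```

When MMD is used as a training loss, X and Y would be the features of two domains, and minimizing `mmd2` pulls their feature distributions together.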

Because MMD can measure the difference between the real data distribution and a generated data distribution, it has also been used in generative adversarial networks (GANs). When training a GAN, MMD can measure the difference between generated and real samples, helping the model generate more realistic samples. Beyond GANs, MMD can be applied to various machine learning tasks, including classification, regression, clustering, and anomaly detection.

② Domain adversarial learning

Domain-adversarial training is widely used to learn domain-invariant features. An example is the domain-adversarial neural network (DANN), proposed for DA, which trains a generator (feature extractor) and a discriminator. The discriminator is trained to distinguish domains, while the generator is trained to fool the discriminator, which pushes it toward domain-invariant feature representations.
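The adversarial update can be sketched in a few lines. The example below is a deliberately simplified NumPy version with hand-written gradients, not DANN itself: it assumes a linear feature extractor `W`, a logistic domain discriminator `v`, and only two domains, and it omits the label classifier that a full DANN trains on the same features. All names are hypothetical. The key point is the flipped sign on the feature extractor's gradient (gradient reversal).

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def domain_loss_and_grads(W, v, x, dom):
    """Binary cross-entropy of the domain discriminator and its gradients.

    W: (d, k) linear feature extractor, v: (k,) logistic discriminator,
    x: (n, d) inputs, dom: (n,) domain labels in {0, 1}.
    """
    z = x @ W                         # features g(x)
    p = sigmoid(z @ v)                # predicted probability of domain 1
    eps = 1e-12
    loss = -np.mean(dom * np.log(p + eps) + (1 - dom) * np.log(1 - p + eps))
    err = (p - dom) / len(dom)        # d loss / d logits
    grad_v = z.T @ err                # gradient for the discriminator
    grad_W = x.T @ np.outer(err, v)   # gradient flowing back into the features
    return loss, grad_W, grad_v

def dann_step(W, v, x, dom, lr=0.1, lam=1.0):
    """One adversarial update: v descends the domain loss, W ascends it."""
    _, grad_W, grad_v = domain_loss_and_grads(W, v, x, dom)
    v_new = v - lr * grad_v             # discriminator: minimize domain loss
    W_new = W - lr * (-lam * grad_W)    # gradient reversal: sign flipped for W
    return W_new, v_new

rng = np.random.default_rng(0)
x = np.vstack([rng.normal(0.0, 1.0, (50, 3)), rng.normal(1.0, 1.0, (50, 3))])
dom = np.concatenate([np.zeros(50), np.ones(50)])
W = rng.normal(0.0, 0.5, (3, 2))
v = rng.normal(0.0, 0.5, 2)
W_next, v_next = dann_step(W, v, x, dom)
```

In a deep-learning framework the same effect is usually obtained with a gradient reversal layer between the feature extractor and the discriminator, so that a single backward pass trains both.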

③ Explicit feature alignment

First, feature distribution alignment. Feature distribution alignment aims to make the feature distributions of different groups of samples identical or similar by some specific method. The goal is for all samples to share the same feature distribution, which improves the performance of machine learning algorithms, especially in cross-domain settings (it can also be used within a single domain). Some ways to achieve feature distribution alignment:

  • Adaptive methods: these map features into a latent space in which the original feature distribution and the target distribution are embedded together, which makes it easier to align the feature distributions of different samples.
  • Methods based on maximum mean discrepancy (MMD): MMD measures the distance between two distributions; minimizing the MMD between two sets of samples makes their feature distributions as similar as possible. It can be implemented with various kernel functions.
  • Unsupervised domain adaptation methods: these use unlabeled data to perform domain adaptation. Unlabeled test-domain data is used to adapt a model trained on the training data so that it extracts features suited to both domains, which aligns the feature distributions.

Feature distribution alignment is an important research area in domain adaptation. It can be achieved with either unsupervised or supervised learning methods; which one to use depends on the specific problem.

DG work based on explicit feature alignment aligns features across source domains to learn domain-invariant representations, through either explicit feature distribution alignment or feature normalization.

Explicit feature distribution alignment achieves alignment by explicitly pulling the feature distributions of multiple source domains as close together as possible (kernel methods and contrastive learning do this implicitly). It involves the following steps:

  • Estimate the feature distribution of each source domain, and use a distance metric such as KL divergence to measure the difference in feature distribution between source domains.
  • Introduce a loss term that minimizes this domain distance, forcing the distributions to be as close as possible.
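The two steps above can be sketched as follows. This is one possible instantiation, not a specific published method: it assumes each domain's features are summarized by a diagonal Gaussian (per-dimension mean and variance), for which the KL divergence has a closed form. The function names are hypothetical.

```python
import numpy as np

def gaussian_kl(feat_p, feat_q, eps=1e-6):
    """KL(N_p || N_q) for diagonal Gaussians fitted to two feature matrices."""
    mu_p, var_p = feat_p.mean(axis=0), feat_p.var(axis=0) + eps
    mu_q, var_q = feat_q.mean(axis=0), feat_q.var(axis=0) + eps
    return 0.5 * np.sum(
        np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0
    )

def alignment_loss(domain_feats):
    """Average symmetric KL over all pairs of source domains."""
    total, pairs = 0.0, 0
    for i in range(len(domain_feats)):
        for j in range(i + 1, len(domain_feats)):
            total += gaussian_kl(domain_feats[i], domain_feats[j])
            total += gaussian_kl(domain_feats[j], domain_feats[i])
            pairs += 1
    return total / max(pairs, 1)

rng = np.random.default_rng(0)
feats = [rng.normal(0.0, 1.0, (100, 4)),  # features of source domain 1
         rng.normal(0.5, 1.0, (100, 4)),  # features of source domain 2
         rng.normal(2.0, 1.5, (100, 4))]  # features of source domain 3
loss = alignment_loss(feats)
```

During training, a term like `alignment_loss` would be added to the task loss with some weight, so that the feature extractor is penalized whenever the source domains drift apart in feature space.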

In addition, feature normalization can be used to normalize the features of multiple source domains so that they have similar statistics, which also achieves feature alignment.
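A minimal sketch of the normalization idea, assuming NumPy: each domain's features are standardized with that domain's own mean and standard deviation, so that all domains end up with matching first- and second-order statistics. The function name is hypothetical, and real methods typically do this inside the network (for example with normalization layers) rather than as a separate step.

```python
import numpy as np

def normalize_per_domain(domain_feats, eps=1e-5):
    """Standardize each domain's features to zero mean and unit variance."""
    normalized = []
    for feats in domain_feats:
        mean = feats.mean(axis=0, keepdims=True)
        std = feats.std(axis=0, keepdims=True)
        normalized.append((feats - mean) / (std + eps))
    return normalized

rng = np.random.default_rng(0)
domains = [rng.normal(0.0, 1.0, (64, 3)),  # domain with small mean and scale
           rng.normal(5.0, 3.0, (64, 3))]  # domain with large mean and scale
aligned = normalize_per_domain(domains)
```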

④ Invariant risk minimization (IRM)

Before introducing IRM, consider the classic baseline: empirical risk minimization (ERM). The empirical risk is the average value of the loss function over the training set (see the earlier notes on "Statistical Learning Methods" by Li Hang). ERM minimizes this average loss on the training set.
ERM is the most classic approach and is often used as a baseline for generalization. Its drawback is that it assumes the test set and training set follow the same distribution and does not account for domain shift.
[Figure: empirical risk minimization]
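The ERM objective described above is usually written as follows. Here n (the number of training samples, pooled over all source domains in the DG setting) and ℓ (the loss function) are notation assumed for this note.

```latex
\min_{h} \; \frac{1}{n} \sum_{i=1}^{n} \ell\big(h(x_i), y_i\big)
```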
Another important concept is that of spurious features.
In cross-domain generalization, a common problem is that the data may contain factors that appear to be related to the label but are not actually causal, called explicit confounding factors here, as well as factors that seem unrelated to the label but still affect predictions, called implicit confounding factors.
Some of these confounding factors arise by chance and have nothing to do with the true label. They are called spurious features: they may be strongly correlated with the label in the training domain but lose that correlation in the test data.
For example, in cat-versus-dog classification, background information can be a spurious feature unrelated to the label. If in the training data only the cat pictures have patterned backgrounds while the dog pictures have plain black backgrounds, the model may learn to treat the pattern as a feature of cats. At test time it will then classify any picture with a patterned background as a cat, which is wrong.
Dealing with spurious features is therefore an important step in domain generalization; methods include, but are not limited to, feature selection, feature transformation, and domain adaptation.

With this background in place, we turn to invariant risk minimization (IRM).
The goal of IRM is to learn features that are invariant across domains rather than spurious features tied to the environment. Writing the spurious features as Xs and the invariant features as Xv, P(Y|Xv) stays constant across environments while P(Y|Xs) varies. The core idea of IRM is to force the model, during training, to be insensitive to confounding variables, so that it keeps stable predictive ability even when those variables change across domains.
IRM therefore suits domain generalization problems in which the data contains confounding variables. In medical image analysis, for example, datasets usually come from multiple hospitals, and there may be observational or operational confounders among them, such as differences in scanning equipment, camera angle, or lighting, that can hurt the model's generalization. In such cases IRM can reduce the influence of the confounders and improve generalization.
However, IRM applies when the environments share the same feature space. If the feature spaces differ across environments, IRM may fail and other methods are needed.
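To make the idea concrete, here is a small sketch of the practical IRM penalty (often called IRMv1) for a linear predictor with squared loss. In each environment the penalty is the squared gradient of the risk with respect to a dummy scale w applied to the predictions, evaluated at w = 1. The toy environments are an assumption of this example: the invariant feature Xv determines Y in every environment, while the spurious feature Xs is tied to Y with a sign that flips between environments.

```python
import numpy as np

def irm_penalty(pred, y):
    """Squared gradient of the squared-error risk w.r.t. a dummy scale w at w = 1.

    R(w) = mean((w * pred - y)^2)  =>  dR/dw at w = 1 is mean(2 * (pred - y) * pred)
    """
    grad = np.mean(2.0 * (pred - y) * pred)
    return grad ** 2

def irm_objective(phi, envs, lam=1.0):
    """Sum over environments of risk + lam * penalty for a linear predictor phi."""
    total = 0.0
    for x, y in envs:
        pred = x @ phi
        total += np.mean((pred - y) ** 2) + lam * irm_penalty(pred, y)
    return total

def make_env(c, n, rng):
    """Toy environment: y = Xv, and the spurious feature is Xs = c * y."""
    xv = rng.normal(size=n)
    y = xv.copy()
    xs = c * y
    return np.stack([xv, xs], axis=1), y

rng = np.random.default_rng(0)
envs = [make_env(1.0, 200, rng), make_env(-1.0, 200, rng)]

phi_invariant = np.array([1.0, 0.0])  # relies only on Xv
phi_spurious = np.array([0.0, 1.0])   # relies only on Xs; fits env 1 perfectly

obj_invariant = irm_objective(phi_invariant, envs)
obj_spurious = irm_objective(phi_spurious, envs)
```

The predictor that uses the spurious feature is perfect in the first environment but is heavily penalized in the second, whereas the invariant predictor has zero risk and zero penalty in both.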

Feature Disentanglement

Disentangled representation learning aims to learn a function that maps samples to feature vectors containing all the information about the different factors of variation, with each dimension (or subset of dimensions) carrying information about only some of the factors.

Disentanglement-based DG methods usually decompose the feature representation into interpretable components (sub-features), where one part is domain-shared (invariant) and the other is domain-specific.

According to the network structure and implementation mechanism, disentanglement-based DG falls into three main categories: multi-component analysis, generative modeling, and causality-inspired methods. The following mainly introduces generative modeling.

From the perspective of the data generation process, generative models can be used for disentanglement. Such methods attempt to model the sample generation mechanism at the domain level, sample level, and label level. Some works further decompose the input into class-irrelevant features that contain information specific to an instance [201]. The Domain Invariant Variational Autoencoder (DIVA) [124] decomposes features into domain information, category information, and other information, all learned in a VAE framework. Peng et al. [125] disentangle fine-grained domain information and category information within a VAE framework. VAEs are also used for disentanglement by Qiao et al. [40], who proposed a unified feature disentanglement network (UFDN) that treats data domains and image attributes of interest as latent factors to be disentangled. Similarly, Zhang et al. [126] disentangle the semantic and variational parts of the samples.

[40] F. Qiao, L. Zhao, and X. Peng, “Learning to learn single domain generalization,” in CVPR, 2020, pp. 12556–12565.
[124] M. Ilse, J. M. Tomczak, C. Louizos, and M. Welling, “DIVA: Domain invariant variational autoencoders,” in Proceedings of the Third Conference on Medical Imaging with Deep Learning, 2020.
[125] X. Peng, Z. Huang, X. Sun, and K. Saenko, “Domain agnostic learning with disentangled representations,” in ICML, 2019.
[126] H. Zhang, Y.-F. Zhang, W. Liu, A. Weller, B. Schölkopf, and E. P. Xing, “Towards principled disentanglement for domain generalization,” in ICML 2021 Machine Learning for Data Workshop, 2021.
[201] Y. Wang, H. Li, L.-P. Chau, and A. C. Kot, “Variational disentanglement for domain generalization,” arXiv preprint arXiv:2109.05826, 2021.

References: "Generalizing to Unseen Domains: A Survey on Domain Generalization"; a review of generalized deep learning; "Xiao Wang Loves Transfer" series, part 28: a survey for a comprehensive understanding of domain generalization in transfer learning


2. [ICLR'23, notable 5%] Sparse Mixture-of-Experts are Domain Generalizable Learners

Generalization to out-of-distribution (OOD) data is an innate ability of human vision but highly challenging for machine learning models. Domain generalization (DG) is one approach to this problem; it encourages models to be resilient to various distribution shifts, such as changes in background, lighting, texture, shape, and geographic/demographic attributes.

From the perspective of representation learning, there are several paradigms for DG: domain alignment, invariant causality prediction, meta-learning, ensemble learning, and feature disentanglement. Recent studies show that these methods improve over ERM and achieve promising results on large-scale DG datasets. Besides ERM, there are many other methods for domain generalization (see DeepDG):
[Figure: list of domain generalization methods in DeepDG]
Meanwhile, across computer vision tasks, innovations in backbone architecture play a key role in improving performance and have attracted wide attention. Prior work has also shown that different CNN architectures perform differently on DG datasets. Inspired by these pioneering works, we conjecture that "backbone architecture design would be promising for DG". (Writing note: "conjecture" is a useful word for introducing a hypothesis that one believes to be a good idea.) To verify this intuition, we evaluate Transformer-based and CNN-based architectures under the same computational overhead. Surprisingly, ViT-S/16 trained with plain ERM outperforms ResNet50 trained with SOTA domain generalization methods on the DomainNet, OfficeHome, and VLCS datasets, even though the two have a similar number of parameters and similar performance on in-distribution domains.
[Figure: comparison of ViT-S/16 and ResNet50 on DG datasets]

Based on the algorithmic alignment framework, we verify this effect theoretically. We first show that a network trained with the ERM loss is more robust to distribution shifts if its architecture is more similar to the invariant correlation, where similarity is formally measured by the alignment value defined by Xu et al. (2020a). Conversely, a network is less robust if its architecture aligns with the spurious correlation. We then investigate the alignment between backbone architectures (i.e., convolution and attention) and the correlations in these datasets, which explains the superior performance of ViT-based methods.

To further improve performance, our analysis suggests that, to address domain generalization, we should exploit the properties of invariant correlations in vision tasks and design network architectures that align with these properties. This calls for investigation at the intersection of domain generalization and classical computer vision. In domain generalization, it is generally believed that data is composed of a set of attributes and that a distribution shift of the data is a distribution shift of these attributes. The latent factorization models for these attributes are almost identical to the generative models for visual attributes in classical computer vision. To capture these diverse attributes, we propose the Generalizable Mixture-of-Experts (GMoE), which builds on sparse mixture-of-experts (sparse MoEs) (Shazeer et al., 2017) and the vision transformer (Dosovitskiy et al., 2021). Sparse MoEs were originally proposed as a key enabler of very large but efficient models. Through theoretical and empirical evidence, this paper demonstrates that MoEs are experts at handling visual attributes, leading to better alignment with invariant correlations.
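To make the sparse MoE idea concrete, here is a minimal NumPy sketch of top-k routing in the style of Shazeer et al. (2017): a gate scores every expert for each token, only the k best experts are evaluated, and their outputs are mixed with softmax weights renormalized over the selected experts. This is not GMoE's actual layer. The linear experts, the shapes, and the function names are simplifying assumptions (real experts are small MLPs inside a transformer block).

```python
import numpy as np

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def sparse_moe(tokens, gate_w, expert_ws, k=2):
    """Route each token to its top-k experts and mix their outputs.

    tokens: (n, d), gate_w: (d, E), expert_ws: (E, d, d), one linear map per expert.
    """
    logits = tokens @ gate_w                       # (n, E) routing scores
    top_idx = np.argsort(-logits, axis=1)[:, :k]   # indices of the k best experts
    top_logits = np.take_along_axis(logits, top_idx, axis=1)
    weights = softmax(top_logits)                  # renormalize over selected experts
    out = np.zeros_like(tokens)
    for n in range(tokens.shape[0]):
        for slot in range(k):
            e = top_idx[n, slot]
            # Only the selected experts are ever evaluated for this token.
            out[n] += weights[n, slot] * (tokens[n] @ expert_ws[e])
    return out, top_idx, weights

rng = np.random.default_rng(0)
tokens = rng.normal(size=(5, 4))        # 5 tokens of dimension 4
gate_w = rng.normal(size=(4, 3))        # gate over 3 experts
expert_ws = rng.normal(size=(3, 4, 4))  # 3 linear experts
out, top_idx, weights = sparse_moe(tokens, gate_w, expert_ws, k=2)
```

Because each token only touches k of the E experts, the number of parameters can grow with E while the compute per token stays roughly constant. This is the sense in which sparse MoEs enable very large but efficient models.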

Contributions of the paper:

  • A Novel View of DG: In contrast to previous works, this paper initiates a formal exploration of the backbone architecture in DG. Based on algorithmic alignment (Xu et al., 2020a), we prove that a network is more robust to distribution shifts if its architecture aligns with the invariant correlation, whereas less robust if its architecture aligns with spurious correlation. The theorems are verified on synthetic and real datasets.
  • A Novel Model for DG: Based on our theoretical analysis, we propose Generalizable Mixture-of-Experts (GMoE) and prove that it enjoys a better alignment than vision transformers. GMoE is built upon sparse mixture-of-experts (Shazeer et al., 2017) and vision transformer (Dosovitskiy et al., 2021), with a theory-guided performance enhancement for DG.
  • Excellent Performance: xxx

Notes for writing a paper:
The first contribution is usually a new perspective and analysis that no one has taken before, with a conclusion established through that observation and analysis.
The second contribution is usually a model designed on the basis of that conclusion.
The third is usually the model's results.


Origin blog.csdn.net/DUDUDUTU/article/details/130782169