[Interpretation of the paper] Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture (I-JEPA)

1. Brief introduction

This paper demonstrates a method for learning highly semantic image representations without relying on hand-crafted data augmentation. It introduces the Image-based Joint-Embedding Predictive Architecture (I-JEPA), a non-generative approach for self-supervised learning from images. The idea behind I-JEPA is simple: from a single context block, predict the representations of various target blocks in the same image. The core design choice guiding I-JEPA toward semantic representations is the masking strategy; specifically, it is crucial to (a) predict several target blocks in the image, (b) sample target blocks at a sufficiently large (semantic) scale, occupying roughly 15%-20% of the image, and (c) use a sufficiently informative (spatially distributed) context block. Empirically, the paper finds that I-JEPA is highly scalable when combined with Vision Transformers. For example, the paper trains a ViT-Huge/16 on ImageNet in under 38 hours using 32 A100 GPUs, achieving strong downstream performance across a wide range of tasks requiring different levels of abstraction, from linear classification to object counting and depth prediction.

2. Research Background

In computer vision, there are two common methods for self-supervised learning from images: invariance-based methods and generative methods.

Invariance-based pre-training methods optimize encoders to produce similar embeddings for two or more views of the same image, where the views are typically constructed using a set of hand-crafted data augmentations such as random scaling, cropping, and color jittering, among others. These pre-training methods can produce representations of a high semantic level, but they also introduce strong biases that can be detrimental to certain downstream tasks, or even to pre-training tasks with different data distributions.

Cognitive learning theory suggests that one driving mechanism behind representation learning in biological systems is the adaptation of an internal model to predict responses to sensory input. This idea is at the heart of self-supervised generative methods, which remove or corrupt parts of the input and learn to predict the corrupted content. In particular, masked denoising approaches learn representations by reconstructing randomly masked patches of the input at the pixel or token level. Compared with view-invariance methods, masked pre-training tasks require less prior knowledge and generalize easily beyond the image modality. However, the resulting representations are usually at a lower semantic level and lag behind invariance-based pre-training in off-the-shelf evaluations (such as linear probing) and in transfer settings with limited supervision for semantic classification tasks. Consequently, a more involved adaptation mechanism (e.g., end-to-end fine-tuning) is required to reap the full benefits of these methods.

In this work, the paper explores how to improve the semantic level of self-supervised representations without using additional prior knowledge encoded through image transformations. To this end, the paper introduces the Image-based Joint-Embedding Predictive Architecture (I-JEPA). Figure 3 provides an illustration of this approach. The idea behind I-JEPA is to predict missing information in an abstract representation space; e.g., given a single context block, predict the representations of different target blocks in the same image, where the target representations are computed by a learned target-encoder network.

Compared to generative methods that make predictions in pixel/token space, I-JEPA makes use of abstract prediction targets, potentially removing unnecessary pixel-level details and thereby leading the model to learn more semantic features. Another core design choice guiding I-JEPA toward semantic representations is the proposed multi-block masking strategy. Specifically, the paper demonstrates the importance of using an informative (spatially distributed) context block to predict several sufficiently large target blocks in the image.

Through extensive empirical evaluation, the paper demonstrates that:

I-JEPA learns powerful off-the-shelf semantic representations without the use of hand-crafted view augmentations (Fig. 1). It outperforms pixel-reconstruction methods such as MAE on ImageNet-1K linear probing, semi-supervised 1% ImageNet-1K, and semantic transfer tasks.

I-JEPA is competitive with view-invariant pre-training methods on semantic tasks and achieves better performance on low-level vision tasks such as object counting and depth prediction. By using a simpler model with less rigid inductive biases, I-JEPA is applicable to a wider set of tasks.

I-JEPA is also scalable and efficient. Pre-training a ViT-H/14 on ImageNet takes about 2400 GPU hours, which is 50% faster than a ViT-B/16 pre-trained with iBOT and 140% faster than a ViT-L/16 pre-trained with MAE. Predicting in representation space significantly reduces the total computation required for self-supervised pre-training.

Self-supervised learning is an approach to representation learning in which a system learns to capture the relationships between its inputs. This objective can be readily described in the framework of energy-based models (EBMs), where the goal of self-supervision is to assign high energy to incompatible inputs and low energy to compatible inputs. Many existing generative and non-generative self-supervised learning methods can indeed be cast in this framework; see Figure 2.

Joint-Embedding Architectures. Invariance-based pre-training can be cast in the EBM framework using a Joint-Embedding Architecture (JEA); see Figure 2a. A joint-embedding architecture learns to output similar embeddings for compatible inputs x, y and dissimilar embeddings for incompatible inputs. In image-based pre-training, compatible x, y pairs are usually constructed by randomly applying hand-crafted data augmentations to the same input image.
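To make the EBM view concrete, the compatibility of a pair (x, y) under a joint-embedding architecture can be written as an energy over the two encoder outputs. The notation below (encoders $s_\theta$, $s_\phi$ and a distance $D$) is ours, not taken from the paper:

```latex
% Energy of a joint-embedding architecture (notation ours):
% low energy <=> the two views map to nearby embeddings.
E_{\theta,\phi}(x, y) \;=\; D\!\big(s_\theta(x),\, s_\phi(y)\big),
\qquad
D(u, v) \;=\; \lVert u - v \rVert_2^2 \quad\text{or}\quad 1 - \cos(u, v).
```

Invariance-based pre-training then amounts to minimizing this energy over compatible (augmented) pairs, which is why collapse to a constant encoder must be explicitly prevented, as discussed next.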

The main challenge with JEAs is representation collapse, in which the energy landscape is flat (i.e., the encoder produces a constant output regardless of the input). Over the past few years, several approaches have been investigated to prevent representation collapse, such as contrastive losses that explicitly push apart embeddings of negative examples, non-contrastive losses that minimize the informational redundancy across embeddings, and clustering-based approaches that maximize the average entropy of the embeddings. There are also heuristic approaches that exploit an asymmetric architectural design between the x-encoder and y-encoder to avoid collapse.
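As one concrete illustration of a non-contrastive anti-collapse term (in the spirit of the redundancy-minimizing losses mentioned above; this is our own sketch, not a mechanism used by I-JEPA):

```python
import torch

def variance_penalty(z: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """VICReg-style variance term (illustrative, not I-JEPA's mechanism):
    penalize embedding dimensions whose standard deviation over the batch
    falls below 1, discouraging the encoder from collapsing to a constant.
    z: (batch, dim) batch of embeddings."""
    std = torch.sqrt(z.var(dim=0) + eps)   # per-dimension std over the batch
    return torch.relu(1.0 - std).mean()    # hinge loss at a target std of 1

# Illustrative usage: total = similarity_loss + lambda_var * variance_penalty(z)
```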

Generative Architectures. Reconstruction-based self-supervised learning methods can also be cast in the EBM framework using generative architectures; see Figure 2b. Generative architectures learn to directly reconstruct a signal y from a compatible signal x, using a decoder network conditioned on an additional (possibly latent) variable z to facilitate reconstruction. In image-based pre-training, a common approach is to use masking to produce compatible x, y pairs, where x is a copy of the image y with some of the patches masked out. The conditioning variable z then corresponds to a set of (possibly learnable) mask and position tokens that specify to the decoder which image patches to reconstruct. Representation collapse is not a concern for these architectures, as long as the informational capacity of z is low relative to the signal y.
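For later contrast with I-JEPA's representation-space loss, here is a minimal sketch of the pixel-space reconstruction objective of a masked generative architecture (MAE-style); the tensor shapes and the name `masked_pixel_loss` are our assumptions:

```python
import torch
import torch.nn.functional as F

def masked_pixel_loss(decoded: torch.Tensor,
                      patches: torch.Tensor,
                      mask: torch.Tensor) -> torch.Tensor:
    """Generative (MAE-style) objective: reconstruct raw pixels, with the loss
    evaluated only on the masked patches.
    decoded, patches: (B, N, patch_dim); mask: (B, N) float, 1.0 = masked."""
    per_patch = F.mse_loss(decoded, patches, reduction="none").mean(dim=-1)  # (B, N)
    return (per_patch * mask).sum() / mask.sum().clamp(min=1)
```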

Joint-Embedding Predictive Architectures. As shown in Figure 2c, joint-embedding predictive architectures are conceptually similar to generative architectures; a key difference, however, is that the loss function is applied in embedding space rather than input space. JEPAs learn to predict the embedding of a signal y from a compatible signal x, using a predictor network conditioned on an additional (possibly latent) variable z to facilitate prediction. The paper's proposed I-JEPA provides an instantiation of this architecture in the context of images using masking; see Figure 3. In contrast to joint-embedding architectures, JEPAs do not seek representations that are invariant to a set of hand-crafted data augmentations, but instead seek representations that are predictive of each other when conditioned on the additional information z. However, as with joint-embedding architectures, representation collapse is also a concern for JEPAs. The paper exploits an asymmetric architecture between the x- and y-encoders to avoid representation collapse in I-JEPA.
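In the same (our) notation, the JEPA energy replaces the decoder with a predictor $g_\phi$ and measures the error in embedding space; in I-JEPA, $z$ corresponds to the positional mask tokens of a target block and $s_{\bar\theta}$ denotes the target encoder:

```latex
% JEPA energy (notation ours): prediction error measured in embedding space.
E(x, y; z) \;=\; \big\lVert\, g_\phi\big(s_\theta(x),\, z\big) \;-\; s_{\bar\theta}(y) \,\big\rVert_2^2
```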

3. Method introduction

The paper now describes the proposed Image-based Joint-Embedding Predictive Architecture (I-JEPA), shown in Figure 3. The overall objective is as follows: given a context block, predict the representations of various target blocks in the same image. The paper uses the Vision Transformer (ViT) architecture for the context encoder, target encoder, and predictor. A ViT consists of a stack of transformer layers, each composed of a self-attention operation followed by a fully connected MLP. The encoder/predictor architecture is reminiscent of the generative masked autoencoder (MAE) approach. A key difference, however, is that I-JEPA is non-generative and the predictions are made in representation space.
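A condensed sketch of one I-JEPA training step is given below in PyTorch-style code. The helper names (`sample_context_block`, `sample_target_blocks`, the `patch_indices`/`target_positions` keyword arguments) are our own assumptions about the interfaces, and the block sampling is simplified relative to the paper; the overall structure — targets taken from a target encoder applied to the full image, a context encoder applied to the visible context patches only, per-block prediction in representation space, and an exponential-moving-average (EMA) update of the target encoder — follows the paper's description.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ema_update(target_encoder, context_encoder, momentum=0.996):
    """After each optimizer step, the target encoder tracks the context
    encoder via an exponential moving average of its weights."""
    for p_t, p_c in zip(target_encoder.parameters(), context_encoder.parameters()):
        p_t.mul_(momentum).add_(p_c, alpha=1.0 - momentum)

def ijepa_loss(images, context_encoder, target_encoder, predictor,
               sample_context_block, sample_target_blocks):
    # 1) Target representations: encode the FULL image with the target encoder,
    #    then select the patch-level representations of each target block.
    with torch.no_grad():
        full_repr = target_encoder(images)                 # (B, N_patches, D)
    target_blocks = sample_target_blocks()                 # list of patch-index tensors
    targets = [full_repr[:, idx] for idx in target_blocks]

    # 2) Context representation: encode only the visible context patches; any
    #    patch that falls inside a target block is removed beforehand, so the
    #    context and the targets never overlap.
    context_idx = sample_context_block(exclude=target_blocks)
    context_repr = context_encoder(images, patch_indices=context_idx)

    # 3) Predict each target block in representation space, conditioned on
    #    positional mask tokens for that block; average the L2 errors.
    return sum(
        F.mse_loss(predictor(context_repr, target_positions=idx), tgt)
        for idx, tgt in zip(target_blocks, targets)
    ) / len(targets)

# Training loop (sketch): loss.backward(); optimizer.step();
# then ema_update(target_encoder, context_encoder).
```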


4. Image Classification

To demonstrate that I-JEPA learns high-level representations without relying on hand-crafted data augmentation, the paper reports results on various image classification tasks using linear probing and partial fine-tuning protocols. In this section, the paper considers self-supervised models pre-trained on the ImageNet-1K dataset. See Appendix A for pre-training and evaluation implementation details. All I-JEPA models are trained at resolution 224×224, unless explicitly stated otherwise.

ImageNet-1K. Table 1 shows performance on the common ImageNet-1K linear-evaluation benchmark. After self-supervised pre-training, the model weights are frozen and a linear classifier is trained on top using the full ImageNet-1K training set. Compared to popular methods such as the masked autoencoder (MAE) and data2vec, which also do not rely on extensive hand-crafted data augmentation during pre-training, I-JEPA significantly improves linear-probing performance while using less compute. In addition, I-JEPA benefits from scale: a ViT-H/16 trained at resolution 448 matches the performance of view-invariant methods such as iBOT without requiring additional hand-crafted data augmentation.
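As a reminder of what the linear-evaluation protocol entails, here is a minimal sketch (our own; `encoder` is assumed to return one pooled feature vector per image): the pre-trained backbone is frozen and only a linear head is trained on the labeled data.

```python
import torch
import torch.nn as nn

def linear_probe(encoder, loader, feat_dim, num_classes=1000, epochs=10, lr=0.01):
    """Freeze the pre-trained encoder and fit a linear classifier on top."""
    encoder.eval()
    for p in encoder.parameters():
        p.requires_grad_(False)

    head = nn.Linear(feat_dim, num_classes)
    opt = torch.optim.SGD(head.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    for _ in range(epochs):
        for images, labels in loader:
            with torch.no_grad():
                feats = encoder(images)        # frozen features, e.g. (B, feat_dim)
            loss = criterion(head(feats), labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return head
```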

Low-Shot ImageNet-1K. Table 2 shows performance on the 1% ImageNet benchmark. Here, a pre-trained model is adapted for ImageNet classification using only 1% of the ImageNet labels, corresponding to roughly 12 or 13 images per class. The model is adapted via fine-tuning or linear probing, depending on what works best for each method. When using a similar encoder architecture, I-JEPA outperforms MAE while requiring fewer pre-training epochs. I-JEPA with a ViT-H/14 architecture matches the performance of a ViT-L/16 pre-trained with data2vec, while using significantly less compute. By increasing the image input resolution, I-JEPA outperforms previous methods, including joint-embedding methods that utilize additional hand-crafted data augmentation during pre-training, such as MSN, DINO, and iBOT.

Transfer learning. Table 3 shows performance on various downstream image classification tasks using a linear probe. I-JEPA significantly outperforms previous methods that do not use augmentations (MAE and data2vec) and narrows the gap to the best view-invariance-based methods that exploit hand-crafted data augmentations during pre-training, even surpassing the popular DINO on CIFAR100 and Places205.

5. Local Prediction Tasks

I-JEPA learns semantic image representations, significantly improving downstream image-classification performance over previous methods such as MAE and data2vec. Furthermore, I-JEPA benefits from scale and can close the gap with, or even surpass, view-invariance-based methods that use additional hand-crafted data augmentations. In this section, the paper finds that I-JEPA also learns local image features and outperforms view-invariance-based methods on low-level and dense prediction tasks such as object counting and depth prediction.

Table 4 shows the performance on various low-level tasks using a linear probe. In particular, after pre-training, the model weights are frozen and a linear model is trained on top to perform object counting and depth prediction on the Clevr dataset. Compared with view-invariance methods such as DINO and iBOT, the I-JEPA method effectively captures low-level image features during pre-training and outperforms them on object counting (Clevr/Count) and, by a large margin, depth prediction (Clevr/Dist).

6. Scalability

Model efficiency. Compared with previous methods, I-JEPA is highly scalable. Figure 5 shows semi-supervised evaluation on 1% ImageNet-1K as a function of GPU hours. I-JEPA requires less compute than previous methods and achieves strong performance without relying on hand-crafted data augmentation. Compared to reconstruction-based methods such as MAE, which directly use pixels as targets, I-JEPA does introduce some additional overhead from computing targets in representation space (about 7% slower per iteration).

Scaling data size. The paper also finds that I-JEPA benefits from pre-training on larger datasets. Table 5 shows transfer-learning performance on semantic and low-level tasks when increasing the size of the pre-training dataset (IN1K vs. IN22K). Transfer-learning performance on these conceptually different tasks improves when pre-training on the larger and more diverse dataset.

Scaling model size. Table 5 also shows that I-JEPA benefits from larger models when pre-training on IN22K. Pre-training a ViT-G/16 significantly improves downstream performance on image classification tasks such as Places205 and iNat18 compared to the ViT-H/14 model. However, the ViT-G/16 model does not improve performance on the low-level downstream tasks; ViT-G/16 uses a larger input patch size, which may be detrimental to local prediction tasks.

7. Predictor Visualizations

The role of the predictor in I-JEPA is to take the output of the context encoder and, conditioned on positional mask tokens, predict the representation of a target block at the location specified by those tokens. One question is whether a predictor conditioned on positional mask tokens learns to correctly capture positional uncertainty in the targets. To investigate this qualitatively, the paper visualizes the outputs of the predictor. After pre-training, the weights of the context encoder and predictor are frozen, and a decoder is trained following the RCDM framework to map the average-pooled predictor output back to pixel space. Figure 6 shows decoder outputs for various random seeds. Characteristics that are common across samples represent information contained in the average-pooled predictor representation. The I-JEPA predictor correctly captures positional uncertainty and produces high-level object parts with correct poses (e.g., the back of a bird and the top of a car). Characteristics that vary across samples represent information not contained in the representation; in this case, the I-JEPA predictor discards precise low-level details and background information.

8. Ablations

Predicting in representation space. Table 7 compares low-shot performance on 1% ImageNet-1K when the loss is computed in pixel space versus representation space. The paper conjectures that a key component of I-JEPA is that the loss is computed entirely in representation space, which enables the target encoder to produce abstract prediction targets that remove irrelevant pixel-level details. From Table 7, it is clear that predicting in pixel space leads to a significant drop in linear-probing performance.
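Schematically, this ablation changes only what the predictor output is regressed onto; the function below is our own sketch of that switch (names and shapes are assumptions, not the paper's code):

```python
import torch.nn.functional as F

def prediction_loss(pred, pixel_patches, target_repr, target_idx, in_pixel_space):
    """Table 7 ablation, schematically: regress the predictor output either onto
    target-encoder representations (I-JEPA) or onto the raw pixels of the same
    target patches. pred: (B, M, .); pixel_patches: (B, N, patch_dim);
    target_repr: (B, N, D); target_idx: M patch indices."""
    target = (pixel_patches if in_pixel_space else target_repr)[:, target_idx]
    return F.mse_loss(pred, target)
```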

Masking strategy. In Table 8, the paper ablates the number of target blocks and the scales of the context and target blocks in the multi-block masking strategy proposed for I-JEPA pre-training (illustrated in Figure 4). The paper trains I-JEPA for 300 epochs with various multi-block settings and compares performance on the 1% ImageNet-1K benchmark using a linear probe. In short, the paper finds that it is important to predict several relatively large (semantic) target blocks, and to use an informative (spatially distributed) context block.

Table 6 performs a similar ablation comparing against other masking strategies. The paper compares with a rasterized masking strategy, in which the image is split into four large quadrants and the goal is to use one quadrant as context to predict the other three. It also compares with the traditional block and random masking strategies commonly used in reconstruction-based methods. In block masking, the target is a single image block and the context is the image complement. In random masking, the target is a set of random (possibly discontinuous) image patches and the context is again the image complement. Note that in all masking strategies considered, there is no overlap between the context block and the target blocks. The proposed multi-block masking strategy is key for guiding I-JEPA to learn semantic representations; even switching to the traditional block masking reduces ImageNet performance by more than 24%.
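To make "multi-block" concrete, here is a sketch of how context and target blocks can be sampled on a ViT patch grid. The number of targets and the scale and aspect-ratio ranges below are illustrative values in the spirit of the paper's settings, not the exact published hyperparameters (see Appendix A of the paper for those).

```python
import random
import torch

def sample_block(grid_h, grid_w, scale_range, aspect_range):
    """Sample a rectangular block of patch indices on a (grid_h x grid_w) grid."""
    area = random.uniform(*scale_range) * grid_h * grid_w
    aspect = random.uniform(*aspect_range)                  # block height / width
    h = max(1, min(grid_h, round((area * aspect) ** 0.5)))
    w = max(1, min(grid_w, round((area / aspect) ** 0.5)))
    top = random.randint(0, grid_h - h)
    left = random.randint(0, grid_w - w)
    rows = torch.arange(top, top + h).repeat_interleave(w)
    cols = torch.arange(left, left + w).repeat(h)
    return rows * grid_w + cols                             # flattened patch indices

def multi_block_masks(grid_h=14, grid_w=14, num_targets=4):
    # Several relatively large, semantic-scale target blocks ...
    targets = [sample_block(grid_h, grid_w, (0.15, 0.2), (0.75, 1.5))
               for _ in range(num_targets)]
    # ... and one large, spatially distributed context block, from which any
    # patch belonging to a target block is removed (no context/target overlap).
    context = sample_block(grid_h, grid_w, (0.85, 1.0), (1.0, 1.0))
    occupied = set(torch.cat(targets).tolist())
    context = torch.tensor([i for i in context.tolist() if i not in occupied],
                           dtype=torch.long)
    return context, targets
```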

9. Conclusion

The paper proposes I-JEPA, a simple yet effective method for learning semantic image representations without relying on hand-crafted data augmentation. It shows that by making predictions in representation space, I-JEPA converges faster than pixel-reconstruction methods and learns representations of a high semantic level. Compared to view-invariance-based approaches, I-JEPA highlights a path for learning general representations with joint-embedding architectures without relying on hand-crafted view augmentations.

For the appendix, see the original paper: https://arxiv.org/abs/2301.08243
