Efficient Image Segmentation with PyTorch: Part 4

1. Description

        In this 4-part series, we'll walk through image segmentation from the ground up using deep learning techniques in PyTorch. This part will focus on how to implement a Vision Transformer-based image segmentation model.

 

Figure 1: Results of running image segmentation using the Vision Transformer model architecture.

        From top to bottom, the input image, the ground truth segmentation mask, and the predicted segmentation mask. Source: Author

2. Article Outline

        In this article, we will take a tour of the transformer architecture that has taken the deep learning world by storm. The transformer is a versatile architecture that can model different modalities such as language, vision, and audio.

        In this article, we will

  1. Learn about the transformer architecture and the key concepts involved
  2. Understand the Vision Transformer architecture
  3. Introduce a Vision Transformer model written from scratch so you can appreciate all the building blocks and moving parts
  4. Follow the journey of the input tensor fed into this model and check how it changes shape
  5. Use this model to perform image segmentation on the Oxford IIIT Pets dataset
  6. Observe the results of this segmentation task
  7. Briefly introduce SegFormer, a state-of-the-art vision transformer for semantic segmentation

        In this article, we will refer to the code and results in this notebook for model training. If you want to reproduce the results, you'll need a GPU to ensure that the notebook runs in a reasonable amount of time.

3. This series of articles

        This series is aimed at readers of all deep learning experience levels. If you want to learn about deep learning and visual AI in practice with some solid theoretical and hands-on experience, you're in the right place! This will be a 4-part series with the following articles:

  1. Concepts and ideas
  2. CNN-based models
  3. Depthwise separable convolutions
  4. Vision Transformer based models (this article)

        Let's start our Vision Transformer tour with an introduction to, and an intuitive understanding of, the transformer architecture.

4. Transformer architecture

        We can think of the transformer architecture as a combination of interleaved communication and computation layers. Figure 2 depicts this idea visually. The transformer has N processing units (N is 3 in Figure 2), and each unit is responsible for processing 1/N of the input. For these processing units to produce meaningful results, each one needs a global view of the input. Thus, the system repeatedly communicates information about the data in each processing unit to every other processing unit; this is shown using the red, green, and blue arrows running from each processing unit to every other processing unit. This communication is followed by some computation based on the exchanged information. After repeating this process sufficiently many times, the model is able to produce the expected results.

Figure 2: Interleaved communication and computation in a transformer. The image shows only 2 layers of communication and computation.

        It's worth noting that most online resources generally discuss both the transformer encoder and decoder, as described in the paper "Attention Is All You Need". However, in this article we will only describe the encoder portion of the transformer.

        Let's take a closer look at what constitutes communication and computation in the transformer.

4.1 Communication in the Transformer: Attention

        In the transformer, communication is implemented by layers called attention layers. In PyTorch, this is called MultiheadAttention. We'll talk about the reason for this name later.

        The docs say:

"Allowing the model to jointly focus on information from different representation subspaces, as stated in the paper: Attention is all you need .

        The attention mechanism takes an input tensor x of shape (batch, length, features) and produces a tensor y of the same shape, so that the features of each element are updated based on which other elements in the same instance it attends to. So, within a sequence of size 'length', the features of each element (a vector of size 'features') are updated with respect to every other element. This is where the quadratic cost of the attention mechanism comes from.
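        Below is a minimal sketch of this shape-preserving behaviour using PyTorch's nn.MultiheadAttention. The batch, length, and feature sizes are illustrative values, not the ones used in the notebook.

import torch
import torch.nn as nn

# Illustrative sizes: batch=1, sequence length=64 (patches), features=512.
B, T, C = 1, 64, 512
x = torch.randn(B, T, C)

# batch_first=True makes the layer accept (batch, length, features) directly.
attention = nn.MultiheadAttention(embed_dim=C, num_heads=8, batch_first=True)

# Self-attention: the same tensor acts as query, key, and value, so every
# element of the sequence can attend to every other element.
y, attention_weights = attention(x, x, x)

print(y.shape)                  # torch.Size([1, 64, 512]) -- same shape as x
print(attention_weights.shape)  # torch.Size([1, 64, 64])  -- quadratic in T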

Figure 3: Attention for the word "it" displayed relative to other words in the sentence. We can see that "it" is paying attention to the words "animal", "too" and "tire(d)" in the same sentence. 

        In the context of vision transformers, the input to the transformer is an image. Let's say this is a 128 x 128 (width, height) image. We split it into multiple smaller patches (16 x 16 each). For a 128 x 128 image, we get 64 patches (the length), with 8 patches per row and 8 rows of patches.

        Each of these 64 patches of size 16 x 16 pixels is considered a separate input to the transformer model. Without going into details, it suffices to think of this process as being driven by 64 different processing units, each processing a single 16x16 image patch.
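        A minimal sketch of how an image can be cut into this sequence of flattened patches, assuming a (1, 3, 128, 128) input and 16x16 patches; the notebook may implement this step differently.

import torch

x = torch.randn(1, 3, 128, 128)   # (B, C, H, W)
P = 16                            # patch size

# Cut H and W into non-overlapping 16x16 tiles: (B, C, 8, 8, 16, 16).
patches = x.unfold(2, P, P).unfold(3, P, P)

# Group the 8x8 grid of patches into a sequence of length 64 and flatten
# each patch into a vector of 3*16*16 = 768 values: (B, 64, 768).
B, C, nh, nw, _, _ = patches.shape
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, nh * nw, C * P * P)

print(patches.shape)  # torch.Size([1, 64, 768])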

        In each round, the attention mechanism in each processing unit looks at the image patch it is responsible for and queries each of the remaining 63 processing units for any information that may be relevant and useful to help it effectively process its own image patch.

        The communication step through attention is followed by computation, which we examine next.

4.2 Computation in Transformers: Multilayer Perceptrons

        The computation in the transformer is nothing more than a multilayer perceptron (MLP) unit. This unit consists of 2 linear layers with a GELU non-linearity in between. Other non-linearities can also be used. The unit first projects the input to 4x its size, then projects it back to 1x, the same size as the input.

        In the code we'll see in the notebook, this class is called MultiLayerPerceptron. The code is shown below.

import torch.nn as nn

class MultiLayerPerceptron(nn.Sequential):
    def __init__(self, embed_size, dropout):
        super().__init__(
            nn.Linear(embed_size, embed_size * 4),
            nn.GELU(),
            nn.Linear(embed_size * 4, embed_size),
            nn.Dropout(p=dropout),
        )
    # end def
# end class
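        The communication and computation pieces are typically tied together with layer normalization and residual connections. Below is a minimal sketch of one such encoder block (a pre-norm variant) that reuses the MultiLayerPerceptron class above; the class and layer names are illustrative and may not match the notebook exactly.

import torch.nn as nn

class EncoderBlock(nn.Module):
    """One round of communication (attention) followed by computation (MLP)."""
    def __init__(self, embed_size, num_heads, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_size)
        self.attention = nn.MultiheadAttention(
            embed_size, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_size)
        self.mlp = MultiLayerPerceptron(embed_size, dropout)
    # end def

    def forward(self, x):
        # Communication: every patch attends to every other patch.
        normed = self.ln1(x)
        attended, _ = self.attention(normed, normed, normed)
        x = x + attended
        # Computation: a per-patch MLP, with a residual connection.
        x = x + self.mlp(self.ln2(x))
        return x
    # end def
# end class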

        Now that we understand how the transformer architecture works at a high level, let's focus on the Vision Transformer, since we will use it to perform image segmentation.

5. The Vision Transformer

        Vision Transformers were originally introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale". The paper discusses how the authors applied the vanilla transformer architecture to the problem of image classification. This is done by splitting the image into patches of size 16x16 and treating each patch as an input token to the model. The transformer encoder model is fed these input tokens and asked to predict the class of the input image.

Figure 4: Source: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.

        In our case we are interested in image segmentation. We can think of it as a pixel-level classification task, since we intend to predict the target class for each pixel.

        We make a small but important change to the vanilla Vision Transformer: we replace the MLP head used for classification with an MLP head that performs pixel-wise classification. The output has a linear layer that is shared by every patch whose segmentation mask the Vision Transformer predicts. This shared linear layer predicts a segmentation mask for each patch that was fed into the model.

        In the case of a Vision Transformer, a patch of size 16x16 is treated as equivalent to a single input token at a single time step.

Figure 5: End-to-end working of the Vision Transformer for image segmentation. Images generated using this notebook .

5.1 Intuition for tensor dimensions in a Vision Transformer

        When using deep CNNs, the tensor dimensions we mostly work with are (N, C, H, W), where the letters stand for the following:

  • N: batch size
  • C: number of channels
  • H: Height
  • W: width

        You can see that this format is geared towards 2D image processing, since it is very specific to the characteristics of images.

        On the other hand, with transformers, things become more generic and domain-independent. What we'll see below applies to vision, NLP, audio, or any other problem where the input data can be represented as a sequence. Remarkably, there is little vision-specific bias in the representation of the tensors as they flow through our Vision Transformer.

        When working with transformers in general, we expect tensors to have the shape (B, T, C), where the letters stand for the following:

  • B: batch size (same as CNN)
  • T: time dimension or sequence length. This dimension is also sometimes called L. In the case of a vision transformer, each image patch corresponds to one position along this dimension. If we have 16 image patches, the value of the T dimension will be 16
  • C: channel or embedding size dimension. This dimension is also sometimes called E. When processing an image, each patch of size 3x16x16 (channels, width, height) is mapped to an embedding of size C via a patch embedding layer. We'll see how to do this later; a small sketch of this layer also follows this list
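        As promised above, here is a minimal sketch of a patch embedding layer, assuming 3x16x16 patches and an embedding size of 512; the variable names are illustrative, not the notebook's.

import torch
import torch.nn as nn

# Each flattened 3x16x16 patch (768 values) is mapped to a C=512 embedding.
patch_embedding = nn.Linear(3 * 16 * 16, 512)

patches = torch.randn(1, 64, 768)      # (B, T, 768) from the patchify step
embeddings = patch_embedding(patches)  # (B, T, C)
print(embeddings.shape)                # torch.Size([1, 64, 512])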

        Let's take a deep dive into how the input image tensor is transformed and processed on its way to predicting segmentation masks.

5.2 The Journey of Tensors in the Vision Transformer

        In a deep CNN (UNet, SegNet, or other CNN-based architectures), the tensor's journey looks like this.

        The input tensor is usually of shape (1, 3, 128, 128). This tensor goes through a series of convolution and max-pooling operations in which its spatial dimensions are reduced and its channel dimension is increased, typically by a factor of 2 each time. This is called a feature encoder. After this, we perform the reverse operation, increasing the spatial dimensions and decreasing the channel dimension. This is called a feature decoder. After decoding, we get a tensor of shape (1, 64, 128, 128). This is then projected into the desired number of output channels C using an unbiased 1x1 pointwise convolution, giving a tensor of shape (1, C, 128, 128).
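        A minimal sketch of that final pointwise projection, assuming 64 decoder channels and C = 3 output classes (the concrete sizes are illustrative):

import torch
import torch.nn as nn

C = 3  # number of output classes (e.g. pet, background, pet border)

# Project the 64 decoder channels to C class channels per pixel.
pointwise = nn.Conv2d(in_channels=64, out_channels=C, kernel_size=1, bias=False)

decoded = torch.randn(1, 64, 128, 128)   # output of the CNN feature decoder
logits = pointwise(decoded)
print(logits.shape)                       # torch.Size([1, 3, 128, 128])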

Figure 6: Typical progression of tensor shapes through a deep CNN for image segmentation. 

        When using a Vision Transformer, the process is much more involved. Let's look at the image below and try to understand how the tensor changes shape at each step.

Figure 7: Typical progression of tensor shapes through a Vision Transformer for image segmentation.

        Let's look at each step in more detail and see how it updates the shape of the tensor flowing through the Vision Transformer. To understand this better, let's use concrete values for the tensor dimensions; a shape-tracing sketch after the list checks these shapes in code.

  1. Batch normalization: the input and output tensors have shape (1, 3, 128, 128). The shape stays the same, but the values are normalized to zero mean and unit variance.
  2. Image to patches: an input tensor of shape (1, 3, 128, 128) is converted into a stack of 16x16 image patches. The output tensor has shape (1, 64, 768).
  3. Patch embedding: the patch embedding layer maps the 768 input channels to 512 embedding channels (in this example). The output tensor has shape (1, 64, 512). The patch embedding layer is basically just an nn.Linear layer in PyTorch.
  4. Position embedding: the position embedding layer takes no input tensor, but effectively contributes a learnable parameter (a trainable tensor in PyTorch) with the same shape as the patch embeddings. This has shape (1, 64, 512).
  5. Add: the patch and position embeddings are added together element-wise to produce the input to the transformer encoder. The shape of this tensor is (1, 64, 512). You'll notice that the main workhorse of the Vision Transformer, the encoder, essentially keeps this tensor shape unchanged.
  6. Transformer encoder: the input tensor of shape (1, 64, 512) flows through multiple transformer encoder blocks, each with multiple attention heads (communication) followed by an MLP layer (computation). The tensor shape remains (1, 64, 512).
  7. Linear output projection: if we want to classify each pixel into one of 10 classes, then we need 10 channels for each 16x16 patch. The nn.Linear layer used for the output projection converts the 512 embedding channels into 16x16x10 = 2560 output channels, so the tensor has shape (1, 64, 2560). In the diagram above, C' = 10. Ideally this would be a multilayer perceptron, since "MLPs are universal function approximators", but we use a single linear layer because this is an educational exercise.
  8. Patch to image: this layer converts the 64 patches encoded as a (1, 64, 2560) tensor back into something that looks like a segmentation mask. This can be 10 single-channel images, or in this case a single 10-channel image, where each channel is the segmentation mask for one of the 10 classes. The output tensor has shape (1, 10, 128, 128).
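        The shape-tracing sketch referenced above strings these steps together with plain PyTorch operations so that each shape from the list can be verified. The layer objects here are standalone illustrations rather than the notebook's classes, and random, untrained weights are used purely to check shapes.

import torch
import torch.nn as nn

B, H, W, P = 1, 128, 128, 16         # batch, image size, patch size
C_embed, num_classes = 512, 10       # embedding size, output classes C'
T = (H // P) * (W // P)              # 64 patches

x = torch.randn(B, 3, H, W)

# 1. Batch normalization: shape stays (1, 3, 128, 128).
x = nn.BatchNorm2d(3)(x)

# 2. Image to patches: (1, 3, 128, 128) -> (1, 64, 768).
patches = x.unfold(2, P, P).unfold(3, P, P)            # (1, 3, 8, 8, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, T, 3 * P * P)

# 3. Patch embedding: (1, 64, 768) -> (1, 64, 512).
tokens = nn.Linear(3 * P * P, C_embed)(patches)

# 4 + 5. Position embedding: a learnable tensor of the same shape, added in.
position_embedding = nn.Parameter(torch.randn(1, T, C_embed))
tokens = tokens + position_embedding                   # (1, 64, 512)

# 6. Transformer encoder blocks keep the shape invariant: (1, 64, 512).
encoder_layer = nn.TransformerEncoderLayer(
    d_model=C_embed, nhead=8, dim_feedforward=4 * C_embed, batch_first=True)
tokens = nn.TransformerEncoder(encoder_layer, num_layers=2)(tokens)

# 7. Linear output projection: 512 -> 16*16*10 = 2560 channels, (1, 64, 2560).
logits = nn.Linear(C_embed, P * P * num_classes)(tokens)

# 8. Patch to image: fold the per-patch masks back into (1, 10, 128, 128).
logits = logits.reshape(B, H // P, W // P, num_classes, P, P)
logits = logits.permute(0, 3, 1, 4, 2, 5).reshape(B, num_classes, H, W)

print(logits.shape)  # torch.Size([1, 10, 128, 128])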

         That's it — we've successfully segmented the input image using the Vision Transformer! Next, let's look at an experiment and some results.

5.3 The Vision Transformer in Action

        This notebook contains all the code for this section.

        In terms of code and class structure, the notebook closely mimics the block diagram above. Most of the concepts mentioned above correspond 1:1 to class names in the notebook.

        There are a few concepts related to the attention layers that are key hyperparameters of our model. We did not go into the details of multi-head attention earlier, since doing so is beyond the scope of this article. If you do not have a basic understanding of the attention mechanism in transformers, we strongly recommend reading the references in the further reading section before proceeding.

        We use the following model parameters for the Vision Transformer for segmentation.

  1. 768 embedding dimensions for the patch embedding layer
  2. 12 transformer encoder blocks
  3. 8 attention heads in each transformer encoder block
  4. 20% dropout in multi-head attention and in the MLP

This configuration can be seen in the VisionTransformerArgs Python data class.

from dataclasses import dataclass

@dataclass
class VisionTransformerArgs:
    """Arguments to the VisionTransformerForSegmentation."""
    image_size: int = 128
    patch_size: int = 16
    in_channels: int = 3
    out_channels: int = 3
    embed_size: int = 768
    num_blocks: int = 12
    num_heads: int = 8
    dropout: float = 0.2
# end class

        A configuration similar to the one used in the previous articles was used during model training and validation. The configuration is specified below; a rough setup sketch follows the list.

  1. Random horizontal flip and color jitter data augmentations are applied to the training set to prevent overfitting
  2. Images are resized to 128x128 pixels in a non-aspect-ratio-preserving resize operation
  3. No input normalization is applied to the images; instead, a batch normalization layer is used as the first layer of the model
  4. The model is trained for 50 epochs using the Adam optimizer with a learning rate of 0.0004 and a StepLR scheduler that decays the learning rate by a factor of 0.8 every 12 epochs
  5. The cross-entropy loss function is used to classify each pixel as belonging to a pet, the background, or a pet border
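        The rough setup sketch mentioned above shows how such a configuration might be wired together in PyTorch. The jitter strengths and the stand-in model are assumptions for illustration; the authoritative values and classes are in the notebook.

import torch
import torch.nn as nn
from torchvision import transforms

# Items 1 and 2: augmentation and resize for the training images. Note that
# geometric augmentations (the flip) must be applied jointly to the image
# and its segmentation mask in the actual pipeline.
train_transform = transforms.Compose([
    transforms.Resize((128, 128)),   # non-aspect-ratio-preserving resize
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
])

# Items 4 and 5: optimizer, LR schedule, and loss. The hyperparameter values
# mirror the list above but should be taken from the notebook.
model = nn.Conv2d(3, 3, kernel_size=1)  # stand-in for VisionTransformerForSegmentation
optimizer = torch.optim.Adam(model.parameters(), lr=0.0004)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=12, gamma=0.8)
criterion = nn.CrossEntropyLoss()  # pixel-wise: pet / background / pet border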

        The model has 86.28M parameters and achieved a validation accuracy of 85.89% after 50 training epochs. This is lower than the 88.28% accuracy achieved by the deep CNN model after 20 training epochs. The gap may be due to a few factors that would need to be verified experimentally.

  1. The last output projection layer is a single nn.Linear layer instead of a multilayer perceptron
  2. The 16x16 patch size is too large to capture fine-grained details
  3. Not enough training epochs
  4. Not enough training data - transformer models are known to need much more data to train effectively than deep CNN models
  5. The learning rate is too low

        We plotted a GIF showing how the model learns to predict segmentation masks for 21 images from the validation set.

Figure 8: A GIF showing the progression of the segmentation masks predicted by the Vision Transformer image segmentation model.

        We noticed something interesting during the early training epochs. The predicted segmentation masks have some strange blocking artifacts. The only reason we can think of is that we decompose the image into patches of size 16x16, and after very few training epochs the model has learned nothing beyond a very coarse notion of whether a given 16x16 patch is covered by a pet or by background pixels.

Figure 9: Blocking artifacts seen in the segmentation masks predicted by the Vision Transformer for image segmentation.

        Now that we have seen a basic Vision Transformer, let us turn our attention to a state-of-the-art vision transformer for segmentation tasks.

5.4 SegFormer: Semantic Segmentation Using Transformers

        The SegFormer architecture was proposed in this paper in 2021. The transformer we saw above is a simpler version of the SegFormer architecture.

Figure 10: The SegFormer architecture. Source: the SegFormer paper.

        Most notably, SegFormer:

  1. Generates 4 sets of patch images with patches of size 4x4, 8x8, 16x16, and 32x32, instead of a single patch image with 16x16 patches
  2. Uses 4 transformer encoders instead of just 1. This feels like a model ensemble
  3. Uses convolutions in the stages before and after self-attention
  4. Does not use position embeddings
  5. Each transformer block processes images at spatial resolutions H/4 x W/4, H/8 x W/8, H/16 x W/16, and H/32 x W/32 (a rough downsampling sketch follows this list)
  6. Likewise, the number of channels increases as the spatial dimensions decrease. This feels similar to a deep CNN
  7. Predictions at the multiple spatial resolutions are upsampled and then merged together in the decoder
  8. An MLP combines all these predictions to produce the final prediction
  9. The final prediction is at spatial resolution H/4 x W/4, not at H x W
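        To make item 5 concrete, here is a rough, illustrative sketch of how a hierarchy of strided convolutions can produce features at H/4, H/8, H/16, and H/32 with increasing channel counts. It is not the actual SegFormer implementation, and the channel counts are assumptions for illustration.

import torch
import torch.nn as nn

x = torch.randn(1, 3, 512, 512)  # (B, C, H, W)

# Illustrative channel counts per stage; SegFormer variants define their own.
channels = [32, 64, 160, 256]
strides = [4, 2, 2, 2]           # cumulative downsampling: /4, /8, /16, /32

features = []
in_ch = 3
for out_ch, stride in zip(channels, strides):
    # A strided conv stands in for one "patch merging + transformer" stage.
    stage = nn.Conv2d(in_ch, out_ch, kernel_size=stride, stride=stride)
    x = stage(x)
    features.append(x)
    in_ch = out_ch

for f in features:
    print(f.shape)
# torch.Size([1, 32, 128, 128])   H/4  x W/4
# torch.Size([1, 64, 64, 64])     H/8  x W/8
# torch.Size([1, 160, 32, 32])    H/16 x W/16
# torch.Size([1, 256, 16, 16])    H/32 x W/32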

6. Conclusion

In Part 4 of this series, we covered the transformer architecture and Vision Transformers in particular. We gained an intuitive understanding of how Vision Transformers work and of the basic building blocks involved in their communication and computation phases. We saw the unique patch-based approach taken by Vision Transformers for predicting segmentation masks and then combining the per-patch predictions together.

We reviewed an experiment that shows Vision Transformers in action and allows a comparison of the results with those of deep CNN approaches. While our Vision Transformer is not state-of-the-art, it is able to achieve reasonably good results. We also provided a glimpse into state-of-the-art methods such as SegFormer.

It should be clear by now that transformers have more moving parts and are more complex than deep CNN based methods. From a raw FLOPs perspective, however, transformers promise better efficiency. In a transformer, the only layer that is really computationally heavy is nn.Linear, which is implemented using optimized matrix multiplication on most hardware. Due to this architectural simplicity, transformers are expected to be easier to optimize and speed up than deep CNN-based methods.

Congratulations on getting this far! We're glad you enjoyed reading our series on Efficient Image Segmentation in PyTorch. If you have any questions or comments feel free to leave them in the comments section.

7. Further Reading

The details of the attention mechanism are beyond the scope of this article. There are many high-quality resources you can refer to for learning more about the attention mechanism. Here are some that we highly recommend.

  1. The Illustrated Transformer
  2. NanoGPT from scratch using PyTorch

Below are links to articles that provide more details on vision transformers.

  1. Implementing a Vision Transformer (ViT) in PyTorch: This article details the implementation of a Vision Transformer for image classification in PyTorch. It's worth noting that their implementation uses einops, which we avoid since this is an education-focused exercise (we do recommend learning and using einops to improve code readability). We instead use native PyTorch operators to permute and rearrange tensor dimensions. Also, the authors use Conv2d instead of linear layers in some places. We wanted to build an implementation of a Vision Transformer that does not use convolutional layers at all.
  2. Vision Transformers: AI Summer
  3. Implementing SegFormer in PyTorch

Dhruv Matani
