Efficient Image Segmentation with PyTorch: Part 1

1. Description

        In this 4-part series, we'll walk through image segmentation from the ground up using deep learning techniques in PyTorch. In this article we will start the series with the basic concepts and ideas needed for image segmentation.

Figure 1: Pet images and their segmentation masks (Source: Oxford-IIIT Pet Dataset)

 

        Image segmentation is a technique for isolating the pixels that belong to specific objects in an image. Isolating object pixels opens the door to interesting applications. For example, in Figure 1, the image on the right is the mask corresponding to the pet image on the left, where the yellow pixels belong to the pet. Once the pixels are identified, we can easily make the pet bigger or change the background of the image. This technique is widely used in the face filter features of various social media applications.

        At the end of this series of articles, our goal is to walk the reader through all the steps required to build a vision AI model and run experiments with different settings using PyTorch.

2. This series of articles

        This series is aimed at readers of all deep learning experience levels. If you want to learn about deep learning and vision AI in practice, with some solid theory and hands-on experience, you're in the right place! This will be a 4-part series with the following articles:

  1. Concepts and Ideas (this article)
  2. CNN-based models
  3. Depthwise Separable Convolution
  4. Vision Transformer-Based Models

3. Introduction to Image Segmentation

        Image segmentation divides or segments an image into regions corresponding to objects, background, and boundaries. See Figure 2, which shows an urban scene. It marks areas corresponding to cars, motorcycles, trees, buildings, sidewalks, and other interesting objects with masks of different colors. These regions are identified by image segmentation techniques.

        Historically, we have decomposed images into regions using specialized image processing tools and pipelines. However, due to the amazing growth of visual data in the past two decades, deep learning has become the preferred solution for image segmentation tasks. It greatly reduces the reliance on experts to construct domain-specific image segmentation strategies, as was done in the past. Deep learning practitioners can train image segmentation models if sufficient training data is available for the task.

Figure 2: Segmented scene from the A2D2 dataset (CC BY-ND 4.0)

3.1 What are the applications of image segmentation?

        Image segmentation has applications in many fields such as communication, agriculture, transportation, healthcare, etc. Moreover, its application grows with the growth of visual data. Here are some examples:

  • In self-driving cars, deep learning models continuously process video feeds from on-board cameras to segment the scene into objects such as cars, pedestrians, and traffic lights, which is critical for the car to operate safely.
  • In medical imaging, image segmentation helps doctors identify areas in medical scans that correspond to tumors, lesions, and other abnormalities.
  • In Zoom video calls, it is used to protect personal privacy by replacing the background with a virtual scene.
  • In agriculture, information on weeds and crop areas identified using image segmentation is used to maintain healthy crop yields.

        You can read more about image segmentation in practice on this page from V7 Labs.

3.2 What are the different types of image segmentation tasks?

        There are many different types of image segmentation tasks, each with its advantages and disadvantages. The two most common types of image segmentation tasks are:

  • Class or Semantic Segmentation: Class segmentation assigns each image pixel a semantic class, such as background, road, car, or person. If there are 2 cars in the image, the pixels corresponding to both cars will be labeled as car pixels. It is commonly used in tasks such as autonomous driving and scene understanding.
  • Object or Instance Segmentation: Object segmentation identifies objects and assigns a mask to each unique object in the image. If there are 2 cars in the image, the pixels corresponding to each car will be identified as belonging to separate objects. Object segmentation is often used to track individual objects, such as a self-driving car programmed to follow a specific car ahead.

Figure 3: Object and Class Segmentation (Source: MS COCO, Creative Commons Attribution License)

In this series, we will focus on class segmentation.
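To make the difference between the two task types concrete, here is a minimal sketch using a toy 4x6 grid with two "cars". The grids, the class ids, and the helper name `distinct_labels` are illustrative assumptions and do not come from any dataset or library.

```python
# Toy 4x6 "image" containing two cars on a background (illustrative values).
# Class (semantic) mask: every car pixel gets the same class id.
BACKGROUND, CAR = 0, 1
class_mask = [
    [0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 1, 1],
    [1, 1, 0, 0, 1, 1],
    [0, 0, 0, 0, 0, 0],
]

# Object (instance) mask: each car gets its own id (0 is still background).
instance_mask = [
    [0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 2, 2],
    [1, 1, 0, 0, 2, 2],
    [0, 0, 0, 0, 0, 0],
]


def distinct_labels(mask):
    """Return the sorted list of label ids that appear in a 2-D mask."""
    return sorted({label for row in mask for label in row})


print("class mask labels:   ", distinct_labels(class_mask))
print("instance mask labels:", distinct_labels(instance_mask))
```

The class mask cannot tell the two cars apart, while the instance mask can. That is the extra information object segmentation provides.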

3.3 Decisions needed to implement efficient image segmentation

        Efficiently training models for speed and accuracy requires many important decisions to be made during the lifecycle of a project. This includes (but is not limited to):

  1. Choose a deep learning framework
  2. Choose a good model architecture
  3. Choose an effective loss function that optimizes the aspect you care about
  4. Avoid overfitting and underfitting
  5. Evaluate the accuracy of the model

In the rest of this article, we'll explore each of the above aspects in more depth and provide links to articles that discuss each topic in more detail than can be covered here.

4. PyTorch for Efficient Image Segmentation

4.1 What is PyTorch?

       "PyTorch is an open-source deep learning framework designed to be flexible and modular for research, with the stability and support needed for production deployment. PyTorch provides a Python package for high-level features such as tensor computation (like NumPy) with strong GPU acceleration, and TorchScript for an easy transition between eager and graph modes. With the latest version of PyTorch, the framework offers graph-based execution, distributed training, mobile deployment, and quantization." (Source: PyTorch on the Meta AI page)

        PyTorch is written in Python and C++, which makes it easy to use and learn as well as efficient to run. It supports a wide range of hardware platforms including (server and mobile) CPUs, GPUs and TPUs.

4.2 Why is PyTorch a good choice for image segmentation?

        PyTorch is a popular choice for deep learning research and development because it provides a flexible and powerful environment for creating and training neural networks. It is an excellent framework choice for implementing deep learning based image segmentation due to the following features:

  • Flexibility: PyTorch is a flexible framework that allows you to create and train neural networks in a variety of ways. You can use pre-trained models or create your own from scratch very easily.
  • Backend support: PyTorch supports multiple backends, such as GPU/TPU hardware.
  • Domain libraries: PyTorch has a rich set of domain libraries that make it easy to work with domain-specific data. For example, for vision (image/video) AI, PyTorch provides a library called Torchvision, which we will use extensively in this series.
  • Ease of use and community adoption: PyTorch is an easy-to-use framework, is well documented, and has a large community of users and developers. Many researchers use PyTorch for their experiments, and the models in their published papers often have a freely available PyTorch implementation.

5. Dataset Selection

        We will use the Oxford IIIT Pet dataset (licensed under CC BY-SA 4.0) for class segmentation. This dataset has 3680 images in the training set, and each image has a segmentation trimap associated with it. Each pixel in the trimap belongs to one of 3 classes:

  1. pet
  2. background
  3. border
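As a rough sketch of how such a three-class segmentation map (a trimap) might be prepared for training, the snippet below shifts 1-based label values to 0-based class indices, which is the form PyTorch's cross-entropy loss expects, and counts the pixels per class. The raw encoding (1 = pet, 2 = background, 3 = border) is an assumption that should be checked against the dataset's own documentation. The toy trimap and the helper names are hypothetical.

```python
from collections import Counter

# Assumed raw trimap encoding (verify against the dataset's README):
# 1 = pet, 2 = background, 3 = border.
CLASS_NAMES = ["pet", "background", "border"]


def to_class_indices(trimap):
    """Shift 1-based trimap values to 0-based class indices."""
    return [[value - 1 for value in row] for row in trimap]


def class_pixel_counts(class_mask):
    """Count how many pixels fall into each named class."""
    counts = Counter(index for row in class_mask for index in row)
    return {name: counts.get(i, 0) for i, name in enumerate(CLASS_NAMES)}


# Hypothetical 4x4 trimap, for illustration only.
toy_trimap = [
    [2, 2, 2, 2],
    [2, 3, 3, 2],
    [3, 1, 1, 3],
    [3, 1, 1, 3],
]

toy_counts = class_pixel_counts(to_class_indices(toy_trimap))
print(toy_counts)
```

Counting pixels per class like this is also a quick way to check how balanced the classes are before choosing a loss function.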

        We chose this dataset because it is sufficiently diverse to give us a non-trivial class segmentation task. At the same time, it is not so complicated that we end up spending our time dealing with things like class imbalance and losing sight of the main problem we are trying to understand and solve, namely class segmentation.

        Other commonly used datasets for image segmentation tasks include:

  1. Pascal VOC (Visual Object Classes)
  2. MS COCO
  3. Cityscapes

6. Efficient Image Segmentation Using PyTorch

        In this series, we will train several models for class segmentation from scratch. There are a number of factors to consider when building and training a model from scratch. Below, we'll cover some of the key decisions you'll need to make in doing so.

6.1 Choosing the right model for your task

        There are many factors to consider when choosing the right deep learning model for image segmentation. Some of the most important factors include:

  • Type of image segmentation task: There are two main types of image segmentation tasks: class (semantic) segmentation and object (instance) segmentation. Since we focus on the simpler class segmentation problem, we will model the problem accordingly.
  • Dataset size and complexity: The size and complexity of the dataset affect the complexity of the model we need to use. For example, if we are dealing with images with small spatial dimensions, we might use a simpler (or shallower) model such as a Fully Convolutional Network (FCN). If we work with large and complex datasets, we may use a more complex (or deeper) model such as U-Net.
  • Availability of pre-trained models: There are many pre-trained models available for image segmentation. These models can be used as a starting point for our own models, or they can be used directly. However, if we use a pre-trained model, we may be constrained by the spatial dimensions of the input images it expects. In this series, we will focus on how to train a model from scratch.
  • Available computing resources: Deep learning models can be expensive to train. If our computing resources are limited, we may need to choose a simpler model or a more efficient model architecture.

        In this series, we will use the Oxford IIIT Pet dataset because it is large enough for us to train a moderately sized model, which in turn requires the use of a GPU. We strongly recommend that you create an account on kaggle.com or use Google Colab's free GPU to run the notebooks and code referenced in this series.

6.2 Model Architecture

        Here are some of the most popular deep learning model architectures for image segmentation:

  • U-Net: U-Net is a convolutional neural network commonly used for image segmentation tasks. It uses skip connections, which help train the network faster and improve overall accuracy. If you have to pick one, U-Net is usually an excellent default choice!
  • FCN: The Fully Convolutional Network (FCN) consists only of convolutional layers, but it is not as deep as U-Net. The lack of depth is mainly because accuracy drops at greater network depths. This makes FCN faster to train, but it may not be as accurate as U-Net.
  • SegNet: SegNet is a popular model architecture similar to U-Net that uses less activation memory than U-Net. We will be using SegNet in this series.
  • Vision Transformer (ViT): Vision Transformers have recently gained popularity due to their simple structure and the applicability of the attention mechanism to text, vision, and other domains. Vision Transformers can be more efficient than CNNs in training and inference, but have historically required more data to train than convolutional neural networks. We will also use ViT in this series.

Figure 4: U-Net model architecture.

        These are just a few of the many deep learning models available for image segmentation. The best model for a particular task will depend on the previously mentioned factors, the specific task, and your own experiments.

6.3 Choosing the right loss function

        The choice of loss function for image segmentation tasks is important because it can have a significant impact on the performance of the model. There are many different loss functions available, each with their own advantages and disadvantages. The most commonly used loss functions in image segmentation are:

  • Cross-entropy loss: Cross-entropy loss measures the difference between the predicted probability distribution and the true probability distribution.
  • IoU loss: IoU loss measures the amount of overlap between the predicted mask and the true mask for each class. It penalizes cases where either precision or recall would suffer. IoU as defined is not differentiable, so we need to tweak it slightly to use it as a loss function.
  • Dice loss: Dice loss is also a measure of the overlap between the predicted mask and the ground truth mask.
  • Tversky loss: Tversky loss was proposed as a robust loss function for dealing with imbalanced datasets.
  • Focal loss: Focal loss is designed to focus on hard-to-classify examples. This helps improve the performance of the model on challenging datasets.
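To illustrate how focal loss focuses on hard examples, the sketch below compares per-pixel cross-entropy with the commonly used focal form, which scales cross-entropy by (1 - p)^gamma, where p is the probability the model assigns to the true class. The value gamma = 2 and the function names are assumptions chosen for illustration. This is plain Python arithmetic, not a library implementation.

```python
import math


def cross_entropy(p_true):
    """Cross-entropy for one pixel, given the probability of its true class."""
    return -math.log(p_true)


def focal_loss(p_true, gamma=2.0):
    """Focal loss for one pixel: cross-entropy scaled by (1 - p)^gamma."""
    return (1.0 - p_true) ** gamma * cross_entropy(p_true)


# Easy pixel (p = 0.9), ambiguous pixel (p = 0.5), hard pixel (p = 0.1).
for p in (0.9, 0.5, 0.1):
    ce, fl = cross_entropy(p), focal_loss(p)
    print(f"p={p:.1f}  cross-entropy={ce:.4f}  focal={fl:.4f}  ratio={fl / ce:.2f}")
```

The ratio column shows that well-classified pixels contribute only a small fraction of their cross-entropy, while hard pixels keep most of it.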

        The optimal loss function for a particular task will depend on the specific requirements of the task. For example, if accuracy is more important, IoU loss or Dice loss might be better choices. If the dataset is imbalanced, then Tversky loss or focal loss might be good choices. The specific loss function used can also affect how quickly the model converges during training.

        The loss function is a hyperparameter of the model, and switching to a different loss based on the results we see can allow us to reduce the loss faster and improve the accuracy of the model.

Default: In this series, we will use the cross-entropy loss, since it is a good default choice when results are not yet known.

You can use the following resources to learn more about loss functions.

  1. PyTorch Loss Functions: The Ultimate Guide
  2. Torchvision: Losses
  3. TorchMetrics

Let's look in detail at the IoU loss we define below, which is a robust alternative to the cross-entropy loss for segmentation tasks.

6.4 Custom IoU Loss

IoU is defined as intersection over union. For image segmentation tasks, we can compute it, for each class, as the number of pixels in the intersection of the model's prediction for that class and the ground truth segmentation mask, divided by the number of pixels in their union.

For example, if we have 2 classes:

  1. background
  2. person

We can then determine which pixels are classified as person, compare them to the ground truth person pixels, and compute the IoU for the person class. Similarly, we can compute the IoU for the background class.

Once we have these class-specific IoU metrics, we can choose to average them unweighted, or weight each of them before averaging to account for any class imbalance.
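The per-class computation and the two averaging options can be sketched as follows. This is a plain-Python illustration on a toy 4x4 mask. The helper names, the convention of returning 1.0 when a class is absent from both masks, and the example weights are assumptions.

```python
def class_iou(pred, target, class_id):
    """IoU of one class between two 2-D label masks of equal shape."""
    intersection = union = 0
    for pred_row, target_row in zip(pred, target):
        for p, t in zip(pred_row, target_row):
            in_pred, in_target = p == class_id, t == class_id
            intersection += in_pred and in_target
            union += in_pred or in_target
    # Assumed convention: a class absent from both masks counts as a perfect match.
    return intersection / union if union else 1.0


def mean_iou(ious, weights=None):
    """Average per-class IoUs, optionally weighting each class."""
    weights = weights or [1.0] * len(ious)
    return sum(w * v for w, v in zip(weights, ious)) / sum(weights)


BACKGROUND, PERSON = 0, 1
target = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
]
pred = [
    [0, 0, 0, 0],
    [0, 1, 1, 1],
    [0, 1, 0, 0],
    [0, 0, 0, 0],
]

ious = [class_iou(pred, target, c) for c in (BACKGROUND, PERSON)]
unweighted = mean_iou(ious)
weighted = mean_iou(ious, weights=[1.0, 3.0])  # emphasize the person class
print(ious, unweighted, weighted)
```

Weighting the smaller person class more heavily lowers the averaged score here, because the model's mistakes are concentrated in that class.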

The IoU metric as defined requires us to compute a hard label for each pixel. This requires the use of the argmax() function, which is not differentiable, so we cannot use this metric directly as a loss function. Therefore, instead of using hard labels, we apply softmax() and use the predicted probabilities as soft labels to compute the IoU metric. This produces a differentiable metric from which we can then compute the loss. For this reason, when used in the context of loss functions, the IoU metric is sometimes referred to as the soft IoU metric.
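A minimal sketch of a soft IoU loss along these lines is shown below. It is one possible formulation, not necessarily the one used later in the series: the function name `soft_iou_loss`, the smoothing constant `eps`, the choice to sum over the whole batch, and the unweighted mean over classes are all assumptions.

```python
import torch
import torch.nn.functional as F


def soft_iou_loss(logits, target, eps=1e-6):
    """Soft IoU loss, returned as 1 - mean soft IoU over classes.

    logits: float tensor of shape (N, C, H, W) with raw model outputs.
    target: long tensor of shape (N, H, W) with class indices in [0, C).
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                      # soft labels
    onehot = F.one_hot(target, num_classes)               # (N, H, W, C)
    onehot = onehot.permute(0, 3, 1, 2).to(probs.dtype)   # (N, C, H, W)

    # Sum over batch and spatial dimensions, keeping one value per class.
    intersection = (probs * onehot).sum(dim=(0, 2, 3))
    union = (probs + onehot - probs * onehot).sum(dim=(0, 2, 3))
    soft_iou = (intersection + eps) / (union + eps)
    return 1.0 - soft_iou.mean()


# Random inputs, only to show that the loss is differentiable.
logits = torch.randn(2, 3, 8, 8, requires_grad=True)
target = torch.randint(0, 3, (2, 8, 8))
loss = soft_iou_loss(logits, target)
loss.backward()  # gradients flow because no argmax is involved
print(float(loss))
```

Because the intersection and union are built from probabilities rather than hard labels, every operation is differentiable and the result can be passed to an optimizer like any other loss.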

If we have a metric (M) between 0.0 and 1.0, we can compute the loss (L) as:

L = 1 - M

However, if your metric takes values between 0.0 and 1.0, here is another trick you can use to convert it into a loss. Compute:

L = -log(M)

That is, take the negative logarithm of the computed metric. This differs significantly from the previous formulation, and you can read about it here and here. Essentially, it penalizes low metric values much more heavily, which leads to better learning for your model.

Figure 6: Comparing the loss due to 1-P(x) vs -log(P(x)). Source: Author.
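The two conversions can be compared numerically with a few lines of plain Python. The function names and the clamping constant `eps` (used to avoid taking log(0)) are assumptions made for this sketch.

```python
import math


def linear_loss(metric):
    """L = 1 - M: the loss shrinks linearly as the metric improves."""
    return 1.0 - metric


def neg_log_loss(metric, eps=1e-12):
    """L = -log(M): the loss grows without bound as the metric approaches 0."""
    return -math.log(max(metric, eps))


for m in (0.99, 0.9, 0.5, 0.1, 0.01):
    print(f"M={m:<5} 1-M={linear_loss(m):.4f}  -log(M)={neg_log_loss(m):.4f}")
```

The slope of 1 - M is constant, while the slope of -log(M) is -1/M, so the negative-log form gives a much stronger training signal when the metric is still poor.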

Using IoU as our loss also brings the loss function closer to capturing what we really care about. There are pros and cons to using an evaluation metric as a loss function. If you're interested in exploring this area further, you can start with this discussion on Stack Exchange.

6.5 Data Augmentation

        To efficiently and effectively train a model to achieve good accuracy, attention needs to be paid to the amount and type of training data used to train the model. The training data you choose to use will significantly affect the accuracy of your final model, so if there's one thing you want to learn from this series of articles, this should be it!

        Usually, we will split the data into 3 parts roughly in the proportions mentioned below.

  1. Training (80%)
  2. Validation (10%)
  3. Test (10%)

        You will train the model on the training set, evaluate the accuracy on the validation set, and repeat the process until you are satisfied with the reported metrics. Only then will you evaluate the model on the test set, and then report the numbers. This is done to prevent any kind of bias from creeping into the model's architecture and hyperparameters used during training and evaluation. In general, the more you adjust settings based on the results of your test data, the less reliable your results will be. Therefore, we must limit our decisions to the results seen on the training and validation datasets.
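As a small illustration of the standard recipe, the sketch below shuffles item indices once with a fixed seed and cuts them into 80/10/10 train, validation, and test subsets. The function name, the seed, and the item count of 1000 are hypothetical. Real projects often use the split files shipped with the dataset instead.

```python
import random


def split_indices(num_items, train_frac=0.8, val_frac=0.1, seed=0):
    """Shuffle indices once, then cut them into train/validation/test lists."""
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)  # fixed seed keeps the split reproducible
    n_train = int(num_items * train_frac)
    n_val = int(num_items * val_frac)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]      # whatever remains goes to the test set
    return train, val, test


train_idx, val_idx, test_idx = split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))
```

Fixing the seed matters: if the split changed between runs, test items could leak into training and make the reported numbers unreliable.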

        In this series, we will not use a separate test dataset. Instead, we will use the test dataset as our validation dataset, and apply data augmentation to it so that we always validate the model on slightly different data. This helps keep us from overfitting our decisions to the validation dataset. It's a bit of a hack, and we do it purely for expediency, as a shortcut. For production model development, you should try to stick to the standard recipe above.

        The dataset we will be using in this series has 3680 images in the training set. While this may seem like a lot of images, we want to make sure our model doesn't overfit to these images since we'll be training the model over multiple epochs.

        In a single training epoch, we train the model on the entire training dataset, and we typically train models for 60 or more epochs in production. In this series, we will only train the model for 20 epochs to reduce iteration time. To prevent overfitting, we will employ a technique called data augmentation, which is used to generate new input data from existing input data. The basic idea behind data augmentation of image inputs is that if you change the image slightly, it feels like a new image to the model, but you can infer that the expected output is the same. Here are some examples of the data augmentation we will apply in this series.

  1. random horizontal flip
  2. random color jitter

        While we will use the Torchvision library to apply the data augmentation described above, we encourage you to evaluate the Albumentations data augmentation library for vision tasks. Both libraries have a rich set of transformations available for image data. We personally continue to use Torchvision simply because that is what we started with. Albumentations supports richer data augmentation primitives that can simultaneously alter the input image as well as the ground truth annotations or masks. For example, if you resize or flip an image, you need to make the same change to the ground truth segmentation mask. Albumentations can do that for you out of the box.
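The key point, that a geometric change must be applied to the image and its mask together, can be shown without any library. The sketch below is a framework-free illustration on nested lists. The helper names are hypothetical, and it does not use the Torchvision or Albumentations APIs.

```python
import random


def hflip(grid):
    """Mirror a 2-D grid (list of rows) left to right."""
    return [row[::-1] for row in grid]


def random_paired_hflip(image, mask, p=0.5, rng=random):
    """Flip image and mask together with probability p, or leave both alone.

    A single random draw decides for both, so they can never get out of sync.
    """
    if rng.random() < p:
        return hflip(image), hflip(mask)
    return image, mask


image = [[10, 20, 30], [40, 50, 60]]   # toy pixel intensities
mask = [[0, 1, 1], [0, 0, 1]]          # toy class labels for the same pixels

flipped_image, flipped_mask = random_paired_hflip(image, mask, p=1.0)
print(flipped_image)
print(flipped_mask)
```

A pixel-level change such as color jitter is different: it alters only the image values, so the mask can be left untouched.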

        Broadly speaking, both libraries support transformations that are applied to images at the pixel level or that change the spatial dimensions of the image. Torchvision calls pixel-level transformations color transforms, and it calls spatial-level transformations geometric transforms.

        Below, we'll see some examples of pixel-level and geometric transformations applied by the Torchvision and Albumentations libraries.

Figure 7: Example of pixel-level data augmentation applied to an image using Albumentations. Source: Albumentations

Figure 8: Example of data augmentation applied to an image using Torchvision transforms. Source: Author (Notebook)

Figure 9: Example of spatial-level transformations applied using Albumentations transforms. Source: Author (Notebook)

6.6 Evaluating model performance

        When evaluating a model's performance, you need to understand how it performs on metrics that represent the quality of the model's performance on real data. For example, for an image segmentation task, we want to know how accurately the model can predict the correct class for a pixel. Therefore, we say that pixel accuracy is the validation metric for this model.

        You can use your evaluation metric as a loss function (why not optimize something you really care about!), except that this might not always be possible.

        In addition to accuracy, we will also track the IoU metric (also known as the Jaccard index) and the custom IoU metric we defined above.

        To learn more about various accuracy metrics available for image segmentation tasks, see:

6.7 Disadvantages of using pixel accuracy as a performance metric

        While the accuracy metric may be a good default choice for measuring performance on image segmentation tasks, it does have its own drawbacks which may be important depending on your specific situation.

        For example, consider an image segmentation task to identify the eyes of people in a picture and label those pixels accordingly. Therefore, the model will classify each pixel as one of the following:

  1. background
  2. eye

        Assume that there is only 1 person in each image and that 98% of the pixels do not correspond to eyes. In this case, the model can simply learn to predict each pixel as a background pixel and achieve 98% pixel accuracy on the segmentation task. Wow!

Figure 10: A face image and the corresponding segmentation mask for its eyes. You can see that the eyes make up only a very small part of the overall image. Source: Adapted from Unsplash

        In this case, it might be a better idea to use the IoU or Dice metric, as IoU captures how much of the prediction was correct and is not necessarily biased by the area that each class or category occupies in the original image. You could even consider using the per-class IoU or Dice coefficient as a metric. This better captures the performance of the model on the task at hand.

        When working with pixel accuracy alone, also tracking the precision and recall of the object for which we are computing the segmentation mask (the eyes in the example above) can capture the detail we are looking for.
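The eye example can be reproduced with a few lines of plain Python. The sketch assumes an image flattened to 100 pixels, 2 of which are eye pixels, and a model that predicts background everywhere. The helper names and the convention of reporting 0.0 for an undefined precision are assumptions.

```python
BACKGROUND, EYE = 0, 1

# 100 pixels, 2% of which are eye pixels (flattened for simplicity).
target = [EYE] * 2 + [BACKGROUND] * 98
pred = [BACKGROUND] * 100  # a model that always predicts background


def pixel_accuracy(pred, target):
    """Fraction of pixels whose predicted label matches the ground truth."""
    return sum(p == t for p, t in zip(pred, target)) / len(target)


def class_scores(pred, target, class_id):
    """Return (iou, precision, recall) for one class over flat label lists."""
    tp = sum(p == class_id and t == class_id for p, t in zip(pred, target))
    fp = sum(p == class_id and t != class_id for p, t in zip(pred, target))
    fn = sum(p != class_id and t == class_id for p, t in zip(pred, target))
    iou = tp / (tp + fp + fn) if tp + fp + fn else 1.0
    precision = tp / (tp + fp) if tp + fp else 0.0  # assumed convention
    recall = tp / (tp + fn) if tp + fn else 0.0
    return iou, precision, recall


print("pixel accuracy:", pixel_accuracy(pred, target))
print("eye (iou, precision, recall):", class_scores(pred, target, EYE))
print("background (iou, precision, recall):", class_scores(pred, target, BACKGROUND))
```

Pixel accuracy looks excellent for this model, while the IoU and recall for the eye class reveal that it never finds a single eye pixel.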

        Now that we've covered most of the theoretical underpinnings of image segmentation, let's take a detour and look at considerations relevant to the inference and deployment of image segmentation for real-world workloads.

6.8 Model Size and Inference Latency

        Last but not least, we want to make sure our model has a reasonable number of parameters, but not too many, since we want a small and efficient model. We will examine this aspect in more detail in future articles on reducing model size using efficient model architectures.

        What matters for inference latency is the number of math operations (mult-adds, that is, multiply-add operations) our model performs. Both the model size and the mult-adds can be displayed using the torchinfo package. While mult-adds are a good proxy for model latency, latency can vary widely across backends. The only real way to determine how a model will perform on a specific backend or device is to profile and benchmark it on that device with the set of inputs you expect to see in a production setting.

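import torch.nn as nn
from torchinfo import summary

# A single fully connected layer: 1000 inputs -> 500 outputs
model = nn.Linear(1000, 500)

# Report per-layer parameter counts and mult-adds for a batch of 1
summary(
  model,
  input_size=(1, 1000),
  col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
  col_width=15,
)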

output:

====================================================================================================
Layer (type:depth-idx)                   Kernel Shape    Output Shape    Param #         Mult-Adds
====================================================================================================
Linear                                   --              [1, 500]        500,500         500,500
====================================================================================================
Total params: 500,500
Trainable params: 500,500
Non-trainable params: 0
Total mult-adds (M): 0.50
====================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 2.00
Estimated Total Size (MB): 2.01
====================================================================================================
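The parameter count and parameter size in this summary can be checked by hand. The sketch below is back-of-the-envelope arithmetic for a fully connected layer, assuming 32-bit (4-byte) parameters and megabytes of 10^6 bytes. The helper name is hypothetical, and this is not how torchinfo computes its numbers internally.

```python
def linear_layer_stats(in_features, out_features, bias=True, bytes_per_param=4):
    """Return (num_params, weight_multiplies, size_mb) for a fully connected layer."""
    weight_params = in_features * out_features  # one weight per input/output pair
    bias_params = out_features if bias else 0   # one bias per output
    num_params = weight_params + bias_params
    size_mb = num_params * bytes_per_param / 1e6  # assumes float32 and 1 MB = 10^6 bytes
    return num_params, weight_params, size_mb


num_params, weight_mults, size_mb = linear_layer_stats(1000, 500)
print(f"params: {num_params:,}")
print(f"weight multiplications per input: {weight_mults:,}")
print(f"parameter size: {size_mb:.2f} MB")
```

The parameter count and size agree with the summary above. The summary reports 500,500 mult-adds, which matches the parameter count including the bias; counting only the weight multiplications gives 500,000.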

7. Article review

        Here's a quick recap of what we've discussed so far.

  • Image segmentation is a technique for dividing an image into segments (source: Wikipedia).
  • There are two main types of image segmentation tasks: class (semantic) segmentation and object (instance) segmentation. Class segmentation assigns each pixel in an image to a semantic class. Object segmentation identifies each individual object in an image and assigns a mask to each unique object.
  • We will use PyTorch as the deep learning framework and the Oxford IIIT Pet dataset in this series on efficient image segmentation.
  • There are many factors to consider when choosing the right deep learning model for image segmentation, including (but not limited to) the type of image segmentation task, the size and complexity of the dataset, the availability of pre-trained models, and available computing resources. Some of the most popular deep learning model architectures for image segmentation include U-Net, FCN, SegNet, and Vision Transformer (ViT).
  • The choice of loss function for image segmentation tasks is important because it can have a significant impact on model performance and training efficiency. For image segmentation tasks, we can use cross-entropy loss, IoU loss, Dice loss, or focal loss (among others).
  • Data augmentation is a valuable technique for preventing overfitting as well as for dealing with insufficient training data.
  • Evaluating the model's performance on the task at hand is important, and the metric used for it must be chosen carefully.
  • Model size and inference latency are important metrics to consider when developing a model, especially if you intend to use it for real-time applications such as face segmentation or background noise removal.

        In the next post, we'll look at a Convolutional Neural Network (CNN), built from scratch using PyTorch, to perform image segmentation on the Oxford IIIT Pet dataset.

Origin blog.csdn.net/gongdiwudu/article/details/132339606