SAM + YOLOv8 image segmentation and object detection

SAM (Segment Anything Model) is a deep learning model created and trained by Meta's research team. It was introduced in a research paper published on April 5, 2023, and immediately attracted widespread public attention: the related Twitter post has accumulated more than 3.5 million views to date.

Computer vision professionals are now turning their attention to SAM—but why?

1. What is SAM?

In the Segment Anything research paper, SAM is described as a "foundation model".

A foundation model is a machine learning model trained on large amounts of data (usually via self-supervised or semi-supervised learning) with the intention of being used and retrained on more specific tasks.

In other words, SAM is a pretrained model designed to be adapted to other tasks (in particular through fine-tuning).

For example, SAM can be fine-tuned to segment only the people in a dataset.

Person segmentation is a downstream task SAM can handle because it was trained on data containing such objects - and many others besides!

2. How is SAM trained?

SAM was trained on the SA-1B dataset, introduced by Meta in parallel with the Segment Anything research paper.

The Facebook parent company's dataset contains more than 11 million images collected from almost every region of the world - an important factor in developing models that generalize well.


Images collected from nearly the entire Earth – SA-1B dataset

These high-quality images (average 1500 × 2250 pixels) are accompanied by 1.1 billion segmentation masks corresponding to the dataset labels.

Meta's goal with this dataset is to provide a reference segmentation dataset for AI research. It is released under a license that allows free use for research purposes.

Although very informative, these masks are class-agnostic. In other words, even though SAM can generate a mask of a person, it cannot tell you that this mask represents a person.

This is an important point to keep in mind, as it means that SAM must be combined with other algorithms to really be useful.

Let's take a closer look.

3. How to use SAM?

First, we need to load two items:

  • the segment-anything GitHub repository, which contains the classes and functions needed to run SAM
  • the pretrained model weights released by Meta's researchers
!pip install git+https://github.com/facebookresearch/segment-anything.git &> /dev/null
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

Next, we create 3 global variables:

  • MODEL_TYPE: SAM architecture to use
  • CHECKPOINT_PATH: Path to the file containing model weights
  • DEVICE: processor to use, "cpu" or "cuda" (if a GPU is available)
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = "/content/sam_vit_h_4b8939.pth"
DEVICE = "cuda" #cpu,cuda

We can now load the SAM model using the sam_model_registry function, indicating the model weights:

from segment_anything import sam_model_registry

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)

After the model is loaded, Meta provides two ways to use it:

  • the generator option, which returns every mask the model can find in an image
  • the predictor option, which returns one or more specific masks from the image, based on prompts

We will explore both options in the following lines.

Before that, let's load an image from the internet on which we will experiment with our model:

from urllib.request import urlopen
import cv2
import numpy as np
from google.colab.patches import cv2_imshow

resp = urlopen('https://images.unsplash.com/photo-1615948812700-8828458d368a?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=2072&q=80')
image = np.asarray(bytearray(resp.read()), dtype='uint8')  # raw bytes -> uint8 array
image = cv2.imdecode(image, cv2.IMREAD_COLOR)              # decode into a BGR image
image = cv2.resize(image, (int(image.shape[1]/2.5), int(image.shape[0]/2.5)))  # downscale to speed up processing

cv2_imshow(image)


Our image contains several people, a dog, and some cars.

Now we will segment this image with SAM, starting with the generator option.

4. Generator

In this section, we'll use the generator version of SAM. This will allow us to obtain the full set of masks the model produces when analyzing the image.

Let's initialize the SamAutomaticMaskGenerator object:

from segment_anything import SamAutomaticMaskGenerator

mask_generator = SamAutomaticMaskGenerator(sam)
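
By default the generator uses Meta's standard settings, but it also accepts tuning parameters. As a rough sketch (the values below are purely illustrative and not taken from the original article):

mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,           # density of the point grid sampled over the image
    pred_iou_thresh=0.88,         # filter masks by the model's predicted quality
    stability_score_thresh=0.95,  # filter masks by stability under thresholding
    min_mask_region_area=100,     # drop tiny disconnected regions (in pixels)
)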

Next, we start mask generation using the generate() function:

masks_generated = mask_generator.generate(image)

For each detected object, this function generates a mask along with other data: SAM returns a set of information (in dictionary form) for every object it detects.

5. Prediction results

We can display the keys obtained for each set of information:

print(masks_generated[0].keys())

Output:

dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])

Each result is a set of 7 fields. The first, "segmentation", is a boolean array marking the position of the detected object: True if the pixel belongs to the object, False otherwise.

The mask can be displayed as follows:

cv2_imshow(masks_generated[3]['segmentation'].astype(int)*255)


Other information in this collection corresponds to the following descriptions:

  • area: mask area (in pixels)
  • bbox: mask bounding box in XYWH format
  • predicted_iou: mask quality score predicted by the model
  • point_coords: the sampled input points that generated this mask
  • stability_score: an additional mask quality score
  • crop_box: the image crop used to generate this mask, in XYWH format

Most practitioners will never use this information, but in some cases it is important to know that SAM generates not only masks but also this kind of additional metadata.

Here is the rest of the information obtained for the mask shown above:

print('area :', masks_generated[3]['area'])
print('bbox :',masks_generated[3]['bbox'])
print('predicted_iou :',masks_generated[3]['predicted_iou'])
print('point_coords :',masks_generated[3]['point_coords'])
print('stability_score :',masks_generated[3]['stability_score'])
print('crop_box :',masks_generated[3]['crop_box'])

Output:

area : 5200
bbox : [499, 284, 92, 70]
predicted_iou : 1.005275845527649
point_coords : [[582.1875, 318.546875]]
stability_score : 0.981315553188324
crop_box : [0, 0, 828, 551]

We can also display the number of masks generated by SAM:

print(len(masks_generated))

Output:

111

SAM generated a total of 111 masks from our image.

6. Display predictions

Using the draw_masks_fromDict function introduced in this article, we can draw all the generated masks on the image (a minimal sketch of this helper follows the call below):

segmented_image = draw_masks_fromDict(image, masks_generated)

cv2_imshow(segmented_image)
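
The draw_masks_fromDict helper itself is defined in the linked article and is not reproduced here. A minimal sketch, reusing the cv2 and numpy imports from earlier and assuming the mask dictionaries produced above, with a random semi-transparent color per mask:

def draw_masks_fromDict(image, masks_generated, alpha=0.5):
    """Overlay every SAM-generated mask on the image with a random color."""
    overlay = image.copy()
    for mask_dict in masks_generated:
        mask = mask_dict['segmentation']               # boolean array of shape (H, W)
        color = np.random.randint(0, 255, 3).tolist()  # random BGR color for this mask
        overlay[mask] = color
    return cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)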


The original image now contains the masks generated by SAM.

In this section, we used the generator version of SAM, which produced 111 masks from the image. In addition to the masks, SAM generated extra detection information. To visualize the model's predictions, we finally plotted all the masks on the original image.

SAM therefore enables us to perform image segmentation. However, the generated masks are unclassified: there is no label to distinguish one mask from another. For example, the masks of people are not associated with a single color, so the resulting segments cannot be grouped by category. The only information obtained here is the position and extent of each object.

Additionally, the generated masks can overlap, since SAM can detect objects inside other objects. On the positive side, this shows that SAM is capable of detecting almost any object in an image: we can segment dogs, cars, and people, but also objects such as wheels, windows, or trousers. The generator version of SAM is therefore able to segment every object in an image, even overlapping ones.

7. Beyond Generator

However, this capability also has a downside: it multiplies the predictions in a given area, which can work against your goal. For example, if you want to detect a person in an image, you have no use for the extra masks corresponding to their jacket and trousers.

Furthermore, since SAM is not trained on labeled data, it is not possible to filter its predictions to keep only the ones we are interested in. Even if we segmented every image in a dataset with the generator version of SAM, we could not easily extract, say, only the masks of people. The SAM generator's ability to segment every object in an image is therefore not suitable for every problem.

For targeted object detection, the generator version of SAM is not the right tool. Instead, we need the predictor version, which lets us prompt SAM to specify our request and the target objects to segment.

8. Predictor

In this section we will use the predictor version of SAM, which will enable us to detect target objects. To do this, we will send SAM a prompt specifying the object we want to detect.

Currently, there are two ways to send prompts to SAM:

  • by points of interest
  • by bounding box

SAM can take as input points of interest (x and y coordinates) for image pixels that represent objects. The object specified by the point of interest will then enable SAM to generate a mask associated with this object.
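
As an illustration, a point prompt could look like the sketch below. It assumes the SamPredictor object that we will set up later in this tutorial (with set_image already called), and the coordinates are arbitrary:

import numpy as np

# One foreground point (label 1) placed on the object of interest; coordinates are illustrative
point = np.array([[350.0, 300.0]])
label = np.array([1])

masks, scores, logits = mask_predictor.predict(
    point_coords=point,
    point_labels=label,
    multimask_output=True,  # a single point is ambiguous, so ask for several candidate masks
)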

SAM can also take as input bounding boxes that separate the outlines of objects in the image. Based on these contours, SAM will generate an appropriate mask.

Note: "Prompt" is a buzzword used in most cases to refer to text requests sent to ChatGPT. However, as SAM shows, prompts are not limited to text requests. It extends to a set of queries that practitioners can send to machine learning models.

Note also that, although this feature is not yet publicly available, Meta has already planned for text prompt understanding in its Segment Anything Model.

For the rest of this tutorial, then, we need a prompt to send to SAM. Bounding boxes are a computer vision standard, so that is what we will use.

9. Using bounding box prompts

If you want to continue with this tutorial, you must first have a bounding box associated with the object you want to segment.

If you don't have bounding boxes for your image, you can easily generate them in a few lines of code using a YOLO model.

You can learn how to quickly generate your own bounding boxes in the tutorial dedicated to the latest version of YOLO; a minimal sketch is also given below.
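
If you just need boxes to follow along, a minimal sketch with the ultralytics package could look like this (the yolov8n.pt checkpoint name is only an example; any YOLOv8 detection weights should work):

# pip install ultralytics
from ultralytics import YOLO

yolo_model = YOLO("yolov8n.pt")  # example checkpoint, downloaded automatically on first use
results = yolo_model(image)      # run detection on the BGR image loaded earlier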

Once we use YOLO on our image, we get something like this:

image_bboxes = image.copy()

boxes = np.array(results[0].to('cpu').boxes.data)

plot_bboxes(image_bboxes, boxes, score=False)


Note: the results variable holds the predictions returned by the YOLO model.
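
The plot_bboxes helper also comes from the YOLO article mentioned above. A minimal sketch that simply draws the box coordinates with OpenCV (the color and line width are arbitrary choices) might look like this:

def plot_bboxes(image, boxes, score=False):
    """Draw boxes given as rows of [x1, y1, x2, y2, confidence, class] on the image."""
    for box in boxes:
        x1, y1, x2, y2 = map(int, box[:4])
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        if score:  # optionally print the confidence above the box
            cv2.putText(image, f"{box[4]:.2f}", (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
    cv2_imshow(image)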

The bounding box obtained using YOLO takes the following form:

print(boxes)

Output:

[[ 495.96 285.65 589.8 356.48 0.89921 2]
[ 270.63 147.99 403.17 496.82 0.79781 0]
…
[ 235.32 279.23 508.93 399.63 0.3193 2]
[ 612.13 303.94 647.61 333.11 0.2854 2]]

The first 4 values are the bounding box corner coordinates (x1, y1, x2, y2), the 5th value is the confidence score of the prediction, and the 6th value is the index of the detected class.
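
In other words, each row can be split as follows (the class indices follow the COCO convention used by the pretrained YOLOv8 weights):

xyxy = boxes[:, :4]      # corner coordinates x1, y1, x2, y2
scores = boxes[:, 4]     # confidence score of each detection
class_ids = boxes[:, 5]  # predicted class index (COCO: 0 = person, 2 = car, 16 = dog)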

Now that we have our prompt, let's initialize the SamPredictor object:

from segment_anything import SamPredictor

mask_predictor = SamPredictor(sam)

Next, we specify the image to be analyzed by SAM:

mask_predictor.set_image(image)

From here on, the tutorial is divided into two parts:

  • single object detection
  • batch object detection

Let's start with the first option.

10. Detecting Single Objects

To predict the mask of an object, we pass the bounding box of that object to the predictor's predict() function:

mask, _, _ = mask_predictor.predict(
    box=boxes[1][:-2]  # keep only x1, y1, x2, y2; drop the confidence and class columns
)

We get a mask in the form of a boolean array indicating the position of the detected object (as before in the "segmentation" key of the dictionary): True if the pixel belongs to the object, False otherwise.

We can draw this mask on the image using the draw_mask function described in this post:
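
The draw_mask helper is likewise not reproduced here. A minimal sketch, reusing the cv2 import from earlier and taking the first of the returned masks (the color and transparency are arbitrary):

def draw_mask(image, mask, color=(0, 255, 0), alpha=0.5):
    """Overlay a boolean mask on a BGR image with the given color and transparency."""
    overlay = image.copy()
    overlay[mask] = color
    return cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)

# predict() returns masks of shape (num_masks, H, W); display the first one
cv2_imshow(draw_mask(image, mask[0].astype(bool)))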

Our image now contains the mask detected by SAM.

Thanks to the prompt given to SAM, we were able to obtain the object's mask and display it on our image.

Now let's see how to detect masks corresponding to all bounding boxes.

11. Detect multiple objects

In order to make predictions on a set of bounding boxes, we need to gather them into a PyTorch tensor.

We then use transform.apply_boxes_torch() to rescale the boxes to the input resolution expected by SAM.

Finally, we use predict_torch() to predict the corresponding masks.

import torch

# Keep only the 4 coordinate columns and move them to the predictor's device
input_boxes = torch.tensor(boxes[:, :-2], device=mask_predictor.device)

# Rescale the boxes to the input resolution expected by SAM
transformed_boxes = mask_predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = mask_predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)

The result is a batch of 13 masks, each carrying an extra singleton dimension: (1, 551, 828).

To manipulate this tensor more easily, let's remove the redundant singleton dimension:

print(masks.shape)
masks = torch.squeeze(masks, 1)
print(masks.shape)

Output:

torch.Size([13, 1, 551, 828])
torch.Size([13, 551, 828])

The advantage of feeding SAM bounding boxes is that each generated mask can be associated with the label of its bounding box, which lets us use color to distinguish the classes when displaying them.

Let's define a color palette, with one color for each class that YOLO can predict:

COLORS = [(89, 161, 197),(67, 161, 255),(19, 222, 24),(186, 55, 2),(167, 146, 11),(190, 76, 98),(130, 172, 179),(115, 209, 128),(204, 79, 135),(136, 126, 185),(209, 213, 45),(44, 52, 10),(101, 158, 121),(179, 124, 12),(25, 33, 189),(45, 115, 11),(73, 197, 184),(62, 225, 221),(32, 46, 52),(20, 165, 16),(54, 15, 57),(12, 150, 9),(10, 46, 99),(94, 89, 46),(48, 37, 106),(42, 10, 96),(7, 164, 128),(98, 213, 120),(40, 5, 219),(54, 25, 150),(251, 74, 172),(0, 236, 196),(21, 104, 190),(226, 74, 232),(120, 67, 25),(191, 106, 197),(8, 15, 134),(21, 2, 1),(142, 63, 109),(133, 148, 146),(187, 77, 253),(155, 22, 122),(218, 130, 77),(164, 102, 79),(43, 152, 125),(185, 124, 151),(95, 159, 238),(128, 89, 85),(228, 6, 60),(6, 41, 210),(11, 1, 133),(30, 96, 58),(230, 136, 109),(126, 45, 174),(164, 63, 165),(32, 111, 29),(232, 40, 70),(55, 31, 198),(148, 211, 129),(10, 186, 211),(181, 201, 94),(55, 35, 92),(129, 140, 233),(70, 250, 116),(61, 209, 152),(216, 21, 138),(100, 0, 176),(3, 42, 70),(151, 13, 44),(216, 102, 88),(125, 216, 93),(171, 236, 47),(253, 127, 103),(205, 137, 244),(193, 137, 224),(36, 152, 214),(17, 50, 238),(154, 165, 67),(114, 129, 60),(119, 24, 48),(73, 8, 110)]

Finally, we can use the draw_masks_fromList function developed in this article (a minimal sketch follows the call below) to draw all of our masks, associating a color with each label:

segmented_image = draw_masks_fromList(image, masks.to('cpu'), boxes, COLORS)

cv2_imshow(segmented_image)
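
As with the other helpers, draw_masks_fromList comes from the article linked above. A minimal sketch, assuming boolean masks on the CPU, the YOLO boxes (whose last column is the class index) and the COLORS list defined above:

def draw_masks_fromList(image, masks, boxes, colors, alpha=0.5):
    """Overlay each mask using the color of the class predicted for its bounding box."""
    overlay = image.copy()
    for mask, box in zip(masks, boxes):
        class_id = int(box[-1])  # last column of the YOLO box is the class index
        overlay[mask.numpy().astype(bool)] = colors[class_id % len(colors)]
    return cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)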


We have displayed all the masks predicted by SAM from the bounding boxes provided by YOLO. Additionally, each mask is colored according to the class indicated by its bounding box, which makes it easy to distinguish the various segmented objects.


Original link: SAM+YOLOv8 concise tutorial—BimAnt

Origin: blog.csdn.net/shebao3333/article/details/132787767