Table of contents
- Foreword
- Set up the runtime environment
- How to use the SAM model
  - Import related libraries and define display functions
  - Import the image to be segmented
  - Object segmentation using different prompt types
    - Method 1: Target segmentation using a single prompt point
    - Method 2: Target segmentation using multiple prompt points
    - Method 3: Use a box to specify a target for segmentation
    - Method 4: Combine points and boxes for target segmentation
    - Method 5: Input multiple boxes at the same time for multi-target segmentation
- Summary
This article shows how to use the SAM model, in particular how different prompts (points, boxes, and combinations of the two) can be used to segment targets. The model also runs well in a CPU-only environment, so it is easy to try out yourself.
I have packaged the Segment-Anything code, the paper PDF, the pre-trained models, and usage notes for anyone who wants to study them.
To get them, follow the WeChat official account (GZH) "Axu Algorithm and Machine Learning" and reply [SAM] to receive the SAM code, paper, pre-trained models, and usage documents.
Foreword
GPT has been extremely popular recently, and I did not expect a large CV model to appear so soon, one that comes with a new dataset, a new paradigm, and strong zero-shot generalization.
The large CV model released this time is not yet as powerful as GPT is in NLP, where one model handles many downstream tasks. Still, it is a good start and probably indicates the future direction of CV.
SAM (Segment Anything Model) unifies the downstream applications of segmentation (one subset of CV tasks), showing that large CV models are feasible. It is likely to change CV research considerably, with many tasks handled in a unified way; it may not be long before detection, segmentation, and tracking are all handled by a single model.
Project address: https://github.com/facebookresearch/segment-anything
Demo: https://segment-anything.com/
Set up the runtime environment
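The code requires python>=3.8, as well as pytorch>=1.7 and torchvision>=0.8.
Install SAM directly from GitHub:
pip install git+https://github.com/facebookresearch/segment-anything.git
The examples below also use OpenCV and matplotlib:
pip install opencv-python matplotlib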
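Before running the examples, it can help to confirm that the installed versions meet these minimums. The sketch below is a small convenience script of our own (the names parse_version, meets_minimum, and check_environment are hypothetical, not part of SAM). It uses only the standard library's importlib.metadata and compares the leading numeric components of each version string, ignoring local suffixes such as +cu118.

```python
import sys
from importlib import metadata

# Minimum versions stated above (python>=3.8, torch>=1.7, torchvision>=0.8).
MINIMUMS = {"torch": "1.7", "torchvision": "0.8"}

def parse_version(version):
    """Turn '2.0.1+cu118' into (2, 0, 1): keep only the leading numeric components."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def meets_minimum(installed, minimum):
    """True if `installed` >= `minimum`, padding the shorter version with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))

def check_environment():
    """Print whether Python, torch and torchvision satisfy the minimums."""
    py_ok = sys.version_info >= (3, 8)
    print("python", sys.version.split()[0], "OK" if py_ok else "too old")
    for package, minimum in MINIMUMS.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            print(package, "not installed")
            continue
        status = "OK" if meets_minimum(installed, minimum) else f"needs >= {minimum}"
        print(package, installed, status)

if __name__ == "__main__":
    check_environment()
```

Comparing tuples of integers avoids the classic string-comparison trap where "1.13" would sort before "1.7".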
How to use the SAM model
Import related libraries and define display functions
The following imports the third-party libraries required to run, and defines the functions used to display points, boxes, and segmentation targets.
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    # Overlay a mask as a semi-transparent colored layer (alpha = 0.6)
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # Foreground points (label 1) as green stars, background points (label 0) as red stars
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    # Draw an [x1, y1, x2, y2] box as a green rectangle outline
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
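show_mask turns a boolean mask into an RGBA overlay by broadcasting: an H×W×1 mask multiplied by a 1×1×4 color gives an H×W×4 image that is fully transparent (alpha 0) wherever the mask is False, so the photo shows through. The numpy check below reproduces just that arithmetic on a toy 2×3 mask without matplotlib; mask_to_rgba is our illustrative name, not a function from the article or from SAM.

```python
import numpy as np

def mask_to_rgba(mask, color=(30 / 255, 144 / 255, 255 / 255, 0.6)):
    """Same arithmetic as show_mask: broadcast an HxW mask against an RGBA color."""
    color = np.asarray(color, dtype=float)
    h, w = mask.shape[-2:]
    return mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

toy_mask = np.array([[True, False, False],
                     [True, True, False]])
overlay = mask_to_rgba(toy_mask)
print(overlay.shape)    # (2, 3, 4)
print(overlay[..., 3])  # alpha channel: 0.6 inside the mask, 0.0 outside
```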
Import the image to be segmented
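image = cv2.imread('images/truck.jpg')
if image is None:  # cv2.imread returns None instead of raising when the file cannot be read
    raise FileNotFoundError('images/truck.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; SAM and matplotlib expect RGB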
plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()
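cv2.imread returns pixels in BGR channel order, while matplotlib and SamPredictor.set_image (whose image_format defaults to "RGB") expect RGB. For a 3-channel image, cv2.COLOR_BGR2RGB amounts to reversing the last axis. The numpy sketch below shows that equivalence on a synthetic image, so it runs without OpenCV or the truck photo.

```python
import numpy as np

# A synthetic 2x2 "BGR" image: every pixel is pure blue in OpenCV's channel order.
bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[..., 0] = 255  # channel 0 is Blue in BGR

# Reversing the channel axis matches cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for 3-channel images.
# ascontiguousarray makes a real copy: a negative-stride view cannot go to torch.from_numpy.
rgb = np.ascontiguousarray(bgr[..., ::-1])

print(rgb[0, 0])  # blue now sits in the last (B) position of RGB order
```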
Object segmentation using different prompt types
First, load the SAM pre-trained model. (All files are included in the package described at the beginning of the article.) This example uses the ViT-B checkpoint on the CPU.
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
sam_checkpoint = "./models/sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = "cpu" # or "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
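model_type must match the checkpoint file: loading ViT-B weights into the vit_h architecture raises an error when the weights are loaded. The official checkpoint filenames listed in the segment-anything README encode the backbone, so a small helper can infer the registry key from the filename. infer_model_type is a hypothetical convenience function of ours, not a SAM API.

```python
import os
import re

# Checkpoint files published in the segment-anything README, keyed by registry name.
OFFICIAL_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}

def infer_model_type(checkpoint_path):
    """Hypothetical helper: read 'vit_b' / 'vit_l' / 'vit_h' out of a checkpoint filename."""
    name = os.path.basename(checkpoint_path)
    match = re.search(r"vit_([bhl])", name)
    if match is None:
        raise ValueError(f"cannot infer model type from {name!r}; set model_type by hand")
    return "vit_" + match.group(1)

print(infer_model_type("./models/sam_vit_b_01ec64.pth"))  # vit_b
```

With this, the loading code could use model_type = infer_model_type(sam_checkpoint) instead of keeping the two settings in sync by hand.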
Calling SamPredictor.set_image encodes the input image (an RGB uint8 array of shape H×W×3). SamPredictor stores this embedding and reuses it for every subsequent segmentation prompt on the same image, so the expensive encoding step runs only once.
predictor.set_image(image)
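set_image does not encode the photo at its original size: SamPredictor resizes it so that the longest side equals the encoder's input size (1024 pixels for the released models), and prompt coordinates are mapped into that resized frame. The sketch below re-implements that arithmetic in plain numpy, following our reading of ResizeLongestSide.get_preprocess_shape and apply_coords in segment_anything/utils/transforms.py; treat it as an illustration, not a drop-in replacement.

```python
import numpy as np

def preprocess_shape(old_h, old_w, long_side=1024):
    """Target (h, w) after scaling so the longest side equals long_side."""
    scale = long_side * 1.0 / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)

def scale_coords(coords, original_size, long_side=1024):
    """Map (x, y) pixel coordinates from the original image into the resized frame."""
    old_h, old_w = original_size
    new_h, new_w = preprocess_shape(old_h, old_w, long_side)
    coords = np.asarray(coords, dtype=float).copy()
    coords[..., 0] *= new_w / old_w  # x scales with width
    coords[..., 1] *= new_h / old_h  # y scales with height
    return coords

print(preprocess_shape(1200, 1800))              # (683, 1024) for the 1200x1800 truck photo
print(scale_coords([[500, 375]], (1200, 1800)))  # the prompt point in encoder coordinates
```

This is why predict accepts prompts in original-image pixel coordinates (it transforms them internally), whereas predict_torch expects boxes that were already transformed with predictor.transform.apply_boxes_torch, as Method 5 does.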
Select a point on the truck image above. A point is given as (x, y) pixel coordinates and carries a label: 1 for a foreground point, 0 for a background point. Several points can be supplied; here we use only one, and it is drawn as a star marker.
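predict expects point_coords as an N×2 array of (x, y) pixel coordinates and point_labels as a length-N array of 0/1 labels. The hypothetical helper below (make_point_prompt is our name, not part of SAM) builds and sanity-checks these arrays, catching mismatched label counts, invalid labels, and points outside the image, which is often a sign that (row, column) order was used instead of (x, y).

```python
import numpy as np

def make_point_prompt(points, labels, image_shape):
    """Hypothetical helper: build (point_coords, point_labels) for SamPredictor.predict.

    points: iterable of (x, y) pixel positions; labels: 1 = foreground, 0 = background;
    image_shape: the (H, W, ...) shape of the RGB image passed to set_image.
    """
    coords = np.asarray(points, dtype=float).reshape(-1, 2)
    labels = np.asarray(labels, dtype=int).reshape(-1)
    if len(coords) != len(labels):
        raise ValueError("need exactly one label per point")
    if not np.isin(labels, (0, 1)).all():
        raise ValueError("labels must be 1 (foreground) or 0 (background)")
    h, w = image_shape[:2]
    x, y = coords[:, 0], coords[:, 1]
    if (x < 0).any() or (x >= w).any() or (y < 0).any() or (y >= h).any():
        raise ValueError("point outside the image; coordinates must be (x, y), not (row, col)")
    return coords, labels

coords, labels = make_point_prompt([(500, 375)], [1], (1200, 1800, 3))
print(coords.shape, labels.shape)  # (1, 2) (1,)
```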
Method 1: Target segmentation using a single prompt point
input_point = np.array([[500, 375]]) # prompt point (x, y)
input_label = np.array([1]) # label of the point (1 = foreground)
plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
Run segmentation with SamPredictor.predict. It returns the masks, a confidence score for each mask, and the low-resolution mask logits.
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
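Parameter description:
- point_coords: coordinates of the prompt points, an N×2 array of (x, y) pixel positions
- point_labels: label of each prompt point, 1 for foreground, 0 for background
- box: the prompt box, in [x1, y1, x2, y2] format
- mask_input: a low-resolution mask (typically the logits from a previous prediction) of shape 1×256×256, used as an additional prompt
- multimask_output: True to return three candidate masks, False to return a single mask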
With multimask_output=True (the default), SAM outputs 3 candidate masks together with their confidence scores. This setting is intended for ambiguous prompts: a single point may lie inside several objects, and multimask_output=True returns masks for several of the objects that contain the point.
In the example below, the three masks are two different window regions and the whole truck, all of which contain the star-shaped prompt point.
masks.shape # (number_of_masks) x H x W
(3, 1200, 1800)
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()
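With multimask_output=True you usually still need to pick one mask. A common choice, and the one Method 2 uses for mask_input, is the mask with the highest predicted score. The sketch below uses stand-in arrays with the shapes predict returns here (3 masks of 1200×1800, 3 scores, 3 low-resolution logits of 256×256); the score values are made up for illustration, not real model output.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-ins with the shapes returned by predictor.predict(..., multimask_output=True).
masks = np.zeros((3, 1200, 1800), dtype=bool)  # one boolean mask per candidate
scores = np.array([0.91, 0.78, 0.97])           # made-up quality scores, one per mask
logits = rng.standard_normal((3, 256, 256))     # low-resolution mask logits

order = np.argsort(scores)[::-1]                # candidate indices, best score first
best = int(order[0])
best_mask = masks[best]                         # (1200, 1800) mask to keep
mask_input = logits[best][None, :, :]           # (1, 256, 256), the shape mask_input expects

print(order.tolist(), best, mask_input.shape)   # [2, 0, 1] 2 (1, 256, 256)
```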
Method 2: Target segmentation using multiple prompt points
A single prompt point is often ambiguous, because several objects may contain it. To get the one object we want, we can place several prompt points on that object.
For example, using two prompt points on the truck below extracts the whole vehicle instead of the windows. Because we want a single object, we set multimask_output=False. The best low-resolution logits from the previous prediction are also passed in as mask_input.
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 1])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
masks.shape
(1, 1200, 1800)
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
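Adding points one click at a time is the usual interactive workflow: each new click appends one coordinate row and one label. The hypothetical add_click helper below (not part of SAM) shows only that array bookkeeping; the predictor call is left as a comment because it needs the loaded model.

```python
import numpy as np

def add_click(point_coords, point_labels, xy, label):
    """Hypothetical helper: append one (x, y) click and its 0/1 label to a point prompt."""
    new_point = np.asarray(xy, dtype=point_coords.dtype).reshape(1, 2)
    new_coords = np.vstack([point_coords, new_point])
    new_labels = np.append(point_labels, label).astype(point_labels.dtype)
    return new_coords, new_labels

# Start from the single foreground click of Method 1, then add a second foreground click.
coords, labels = np.array([[500, 375]]), np.array([1])
coords, labels = add_click(coords, labels, (1125, 625), 1)
print(coords.tolist(), labels.tolist())  # [[500, 375], [1125, 625]] [1, 1]

# With the model loaded, the refined prompt would be passed as in the code above:
# masks, scores, logits = predictor.predict(point_coords=coords, point_labels=labels,
#                                           mask_input=mask_input[None, :, :],
#                                           multimask_output=False)
```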
If we only want the window, we can mark the second point as a background point (label=0, the red star in the figure below) to exclude the rest of the truck.
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
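To quantify how much a background point changed the result, compare the masks before and after the extra click. The sketch computes intersection-over-union (IoU) and the fraction of the earlier mask that was removed. The masks are synthetic rectangles standing in for the truck and window masks, and mask_iou is our helper name, not a SAM function; with real output, index masks[0] first to drop the leading dimension.

```python
import numpy as np

def mask_iou(a, b):
    """Intersection over union of two boolean masks of the same shape."""
    a, b = np.asarray(a, dtype=bool), np.asarray(b, dtype=bool)
    union = np.logical_or(a, b).sum()
    return float(np.logical_and(a, b).sum() / union) if union else 1.0

# Synthetic stand-ins: a large "whole vehicle" mask and a smaller "window" mask inside it.
before = np.zeros((100, 100), dtype=bool)
before[20:80, 10:90] = True   # 60 x 80 = 4800 pixels
after = np.zeros_like(before)
after[20:40, 10:40] = True    # 20 x 30 = 600 pixels, fully inside `before`

removed = np.logical_and(before, ~after).sum() / before.sum()
print(round(mask_iou(before, after), 3), round(float(removed), 3))  # 0.125 0.875
```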
Method 3: Use a box to specify a target for segmentation
SAM also accepts a box prompt in [x1, y1, x2, y2] format (top-left and bottom-right corners). In the example below, a box drawn around one of the truck's wheels segments that wheel.
input_box = np.array([425, 600, 700, 875])
masks, _, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()
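Box prompts must be [x1, y1, x2, y2] in pixels of the original image. If boxes come from a tool that uses [x, y, width, height], convert them first; conversely, a mask can be turned back into a tight box, which is handy for re-prompting. Both helpers below are illustrative names of ours, not part of SAM. The [None, :] in the call above only adds a leading dimension, turning a (4,) box into a (1, 4) array.

```python
import numpy as np

def xywh_to_xyxy(box):
    """[x, y, w, h] -> [x1, y1, x2, y2]."""
    x, y, w, h = box
    return np.array([x, y, x + w, y + h])

def mask_to_box(mask):
    """Tight [x1, y1, x2, y2] box (inclusive pixel indices) around a mask's True pixels."""
    ys, xs = np.nonzero(mask)
    if len(xs) == 0:
        return None  # empty mask: no box
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])

wheel_box = xywh_to_xyxy([425, 600, 275, 275])
print(wheel_box, wheel_box[None, :].shape)  # [425 600 700 875] (1, 4)

toy = np.zeros((10, 12), dtype=bool)
toy[2:5, 3:9] = True
print(mask_to_box(toy))  # [3 2 8 4]
```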
Method 4: Combine points and boxes for target segmentation
The following example removes the hub at the center of the wheel so that only the outer tire remains. The box selects the wheel, and a point labeled as background (input_label = np.array([0])) placed on the hub cuts it out of the mask.
input_box = np.array([425, 600, 700, 875])
input_point = np.array([[575, 750]])
input_label = np.array([0])
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_box,
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
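For this trick to remove the hub, the background point has to sit on the hub, that is, inside the box. The hypothetical check below (our own helper, not a SAM API) reports which prompt points fall within an [x1, y1, x2, y2] box, which is a cheap way to catch coordinate mistakes before calling predict.

```python
import numpy as np

def points_in_box(points, box):
    """Boolean array: which (x, y) points lie inside the [x1, y1, x2, y2] box (edges included)."""
    pts = np.asarray(points, dtype=float).reshape(-1, 2)
    x1, y1, x2, y2 = box
    return (pts[:, 0] >= x1) & (pts[:, 0] <= x2) & (pts[:, 1] >= y1) & (pts[:, 1] <= y2)

input_box = np.array([425, 600, 700, 875])
print(points_in_box([[575, 750]], input_box).tolist())              # [True]: the hub point above
print(points_in_box([[500, 375], [575, 750]], input_box).tolist())  # [False, True]
```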
Method 5: Input multiple boxes at the same time for multi-target segmentation
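Entering several boxes at once segments the object in each box in a single batch. The boxes are given as a torch tensor on the predictor's device, converted to the model's input frame with predictor.transform.apply_boxes_torch, and passed to predictor.predict_torch. The example below segments different parts of the truck.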
input_boxes = torch.tensor([
[75, 275, 1725, 850],
[425, 600, 700, 875],
[1375, 550, 1650, 800],
[1240, 675, 1400, 750],
], device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
masks.shape # (batch_size) x (num_predicted_masks_per_input) x H x W
torch.Size([4, 1, 1200, 1800])
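Unlike predict, predict_torch does not transform prompts itself, which is why the boxes first go through predictor.transform.apply_boxes_torch. Conceptually, that step treats each box as two corner points and rescales them into the resized frame whose longest side is 1024 pixels. The numpy sketch below mirrors that arithmetic for the four boxes above on the 1200×1800 image; it is our re-implementation for illustration, assuming the default 1024-pixel target length.

```python
import numpy as np

def resize_boxes(boxes, original_size, long_side=1024):
    """Rescale [x1, y1, x2, y2] boxes into the longest-side-resized frame."""
    old_h, old_w = original_size
    scale = long_side * 1.0 / max(old_h, old_w)
    new_h, new_w = int(old_h * scale + 0.5), int(old_w * scale + 0.5)
    corners = np.asarray(boxes, dtype=float).reshape(-1, 2, 2).copy()  # each box -> two (x, y) corners
    corners[..., 0] *= new_w / old_w  # x coordinates scale with width
    corners[..., 1] *= new_h / old_h  # y coordinates scale with height
    return corners.reshape(-1, 4)

boxes_xyxy = np.array([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
])
resized = resize_boxes(boxes_xyxy, (1200, 1800))
print(resized.shape)        # (4, 4)
print(resized[1].round(1))  # the wheel box in encoder coordinates
```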
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
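predict_torch returns masks of shape (batch, 1, H, W), one mask per box. To save or post-process them as a single image, they can be merged into an instance label map where pixel value k means "belongs to box k" and 0 is background. The sketch uses numpy stand-ins (with real output, call masks.cpu().numpy() first); masks_to_label_map is our illustrative name, and letting later masks overwrite earlier ones where they overlap (such as the wheel inside the whole-truck box) is an arbitrary choice.

```python
import numpy as np

def masks_to_label_map(masks):
    """(N, 1, H, W) or (N, H, W) boolean masks -> (H, W) int map: 0 = background, k = mask k-1."""
    masks = np.asarray(masks, dtype=bool)
    if masks.ndim == 4:
        masks = masks[:, 0]  # drop the per-box mask dimension
    label_map = np.zeros(masks.shape[1:], dtype=np.int32)
    for k, mask in enumerate(masks, start=1):
        label_map[mask] = k  # later masks overwrite earlier ones on overlap
    return label_map

stand_in = np.zeros((2, 1, 4, 6), dtype=bool)
stand_in[0, 0, :, :] = True      # a large mask covering everything
stand_in[1, 0, 1:3, 2:4] = True  # a small mask nested inside it
label_map = masks_to_label_map(stand_in)
print(np.bincount(label_map.ravel()).tolist())  # [0, 20, 4]: pixel counts per label
```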
Summary
That is how to use the SAM model: different prompt types (points, boxes, and their combinations) produce different segmentation results. Overall the results are very good, and the key advantage is that it runs well in a CPU-only environment. Give it a try yourself!