[CV Large Model SAM (Segment Anything)] How to use the powerful SAM model that segments everything: obtain the desired segmentation target through different prompts

This article mainly introduces how to use the SAM model: how to use different prompts for target segmentation. Moreover, the model runs quickly even in a CPU-only environment, which is really nice. Hurry up and try it~

The code of the Segment Anything model, the PDF of the paper, the pre-trained models, the usage instructions, and so on have been packaged up for readers who want to study them. They can be obtained as follows:

Follow the WeChat official account (GZH) on the card at the end of the article: Axu Algorithm and Machine Learning, and reply [SAM] to get the SAM code, paper, pre-trained models, usage documents, etc.

Foreword

Recently GPT has been all the rage. I didn't expect a large CV model to arrive so soon, and one that comes with a new dataset + a new paradigm + strong zero-shot generalization.
The CV large model that has appeared this time is not yet as powerful as GPT in NLP, where one model handles many downstream tasks, but it is a good start, and it should be the future direction of CV as well.
The emergence of SAM (the Segment Anything Model) unifies the downstream applications of the segmentation task (a subset of CV tasks), showing that a large CV model is possible. It will surely bring big changes to CV research, with many tasks handled in a unified way. It may not be long before detection, segmentation, and tracking are all in one.

Project address: https://github.com/facebookresearch/segment-anything
Demo: https://segment-anything.com/

Installing the runtime environment

Running SAM requires python>=3.8, pytorch>=1.7, and torchvision>=0.8.
Install the dependencies:

pip install git+https://github.com/facebookresearch/segment-anything.git
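
The pre-trained checkpoints can also be downloaded directly from the official repository. Below is a minimal sketch for fetching the smallest checkpoint (vit_b, the one used later in this article); the URL comes from the repo's README, while the ./models/ destination is just this article's own convention:

import os
import urllib.request

# Official ViT-B SAM checkpoint (roughly 375 MB), URL taken from the
# segment-anything README.
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
os.makedirs("models", exist_ok=True)
dest = os.path.join("models", "sam_vit_b_01ec64.pth")
if not os.path.exists(dest):
    urllib.request.urlretrieve(url, dest)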

How to use the SAM model

Import related libraries and define display functions

The following imports the third-party libraries required to run the example, and defines the helper functions used to display points, boxes, and segmentation masks.

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_mask(mask, ax, random_color=False):
    # Overlay one HxW boolean mask on the axes as a translucent color.
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # Foreground points (label 1) are drawn as green stars,
    # background points (label 0) as red stars.
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    # Draw an [x1, y1, x2, y2] box as a green rectangle outline.
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))

Load the image to be segmented

image = cv2.imread('images/truck.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; SAM expects RGB
plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()

[Image: the truck image to be segmented]

Object Segmentation Using Different Prompt Types

First, load the SAM pre-trained model. [All the files have been packaged at the end of the article; interested readers can obtain them there.]

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "./models/sam_vit_b_01ec64.pth"
model_type = "vit_b"

device = "cpu"  # or  "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)
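
For reference, the official repo ships three checkpoint sizes; switching between them only means changing model_type and the checkpoint path. A small sketch (registry keys and file names as listed in the official README; larger models are more accurate but slower, especially on CPU):

# Official SAM variants: registry key -> checkpoint file name.
checkpoints = {
    "vit_b": "sam_vit_b_01ec64.pth",  # smallest, fastest (used here)
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_h": "sam_vit_h_4b8939.pth",  # largest, most accurate
}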

Calling SamPredictor.set_image encodes the input image, and SamPredictor then uses this encoding for all subsequent segmentation queries.

predictor.set_image(image)
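
As a side note, set_image assumes RGB input by default, which is why we converted from OpenCV's BGR earlier; the predictor in the official repo also accepts an image_format argument to state the channel order explicitly. The call below is equivalent to the one above:

# Equivalent to the call above; image_format defaults to "RGB".
predictor.set_image(image, image_format="RGB")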

On the picture of the truck above, select a point. A point is given as an (x, y) coordinate together with a label: 1 for a foreground point or 0 for a background point. Multiple points can be provided; here we use only one, and the selected point is displayed as a star marker.

Method 1: Target segmentation using a single prompt point

input_point = np.array([[500, 375]])  # prompt point
input_label = np.array([1])  # label of the point (1 = foreground)
plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()  

[Image: the truck image with the prompt point shown as a star]
Segment with SamPredictor.predict; the model returns the masks together with the confidence score of each segmentation target.

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

Parameter description:

point_coords: coordinates of the prompt points
point_labels: labels of the prompt points, 1 = foreground, 0 = background
box: a prompt box
multimask_output: whether to output multiple masks (True) or a single mask (False)

With multimask_output=True (the default), the SAM model outputs 3 segmentation targets and their corresponding confidence scores. This setting is mainly for ambiguous prompt points: a single point may lie inside several plausible targets, and with multimask_output=True every target containing the prompt point can be returned.
In the following example, two window variants and the whole truck all contain the star-marked prompt point.

masks.shape  # (number_of_masks) x H x W
(3, 1200, 1800)
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {
      
      i+1}, Score: {
      
      score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  

[Images: the three candidate masks with their scores, two window variants and the whole truck]
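
If only a single mask is needed, a common follow-up step is to keep the candidate with the highest score (a minimal sketch; the variable name is our own):

# Keep only the highest-scoring of the three candidate masks.
best_mask = masks[np.argmax(scores)]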

Method 2: Target segmentation using multiple prompt points

A single prompt point is often ambiguous, because multiple targets may contain the point. To get the single target we want, we can place several prompt points on it.
For example, two prompt points on the truck below extract the segmentation of the whole truck rather than a window. For a single-object result, set multimask_output=False. The best low-resolution mask from the previous prediction is also passed back via mask_input to help refine the result.

input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
masks.shape
(1, 1200, 1800)
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show() 

[Image: the whole truck segmented using two foreground points]

If we only want the segmentation of the window, we can add a background point (label=0, the red star in the figure below) to exclude the rest of the truck.

input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show() 

[Image: the window segmented using one foreground point and one background point]

Method 3: Use a box to specify a target for segmentation

The SAM model can also take a box as input, in the format [x1, y1, x2, y2]. Below, one of the truck's wheels is segmented with a box prompt.

input_box = np.array([425, 600, 700, 875])
masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()


[Image: the wheel segmented by the box prompt]
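
The masks returned by predict are boolean HxW arrays, so a result can be written straight to disk. A minimal sketch (the file name is our own):

# masks[0] is a boolean HxW array; scale it to 0/255 and save as a PNG.
cv2.imwrite('wheel_mask.png', masks[0].astype(np.uint8) * 255)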

Method 4: Combine points and boxes for target segmentation

The following example removes the hub at the center of the wheel, keeping only the tire.
The box selects the wheel; the point is labeled as background (input_label = np.array([0])), which excludes the hub region.

input_box = np.array([425, 600, 700, 875])
input_point = np.array([[575, 750]])
input_label = np.array([0])
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

[Image: the tire segmented by combining the box with a background point on the hub]

Method 5: Inputting multiple boxes at once for multi-target segmentation

Multiple boxes can be passed at the same time to segment the object in each box in a single batched call; the boxes are first transformed into the model's input frame with predictor.transform.apply_boxes_torch and then given to predict_torch. Below is the segmentation result for different parts of the truck.

input_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)
masks.shape  # (batch_size) x (num_predicted_masks_per_input) x H x W
torch.Size([4, 1, 1200, 1800])
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()

[Image: the four boxed objects, each segmented with a randomly colored mask]
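
The batched output is a torch tensor of shape (num_boxes, 1, H, W), which makes simple per-object statistics easy. For instance (a minimal sketch), the pixel area of each segmented object:

# Pixel area of each predicted object; the result has shape (4,).
areas = masks.squeeze(1).sum(dim=(1, 2))
print(areas)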

Summary

The above covers how to use the SAM model; different prompt types yield different segmentation results. Overall the results are very good, and crucially the model runs quickly even in a CPU-only environment. Interested readers can try it out themselves~

If this article helped you, a like + follow is appreciated!

Follow the WeChat official account (GZH) on the card below: Axu Algorithm and Machine Learning, and reply [SAM] to get the SAM code, paper, pre-trained models, usage documents, etc. You are welcome to learn and exchange ideas together.

Original post: https://blog.csdn.net/qq_42589613/article/details/130061434