Segment Anything使用手册(交互式数据标柱|自动数据标柱)

主要内容包含segment-anything项目的安装、基于SamPredictor对单点输入生成mask、基于SamPredictor对多点输入生成mask、基于SamAutomaticMaskGenerator自动生成mask。

Segment Anything项目是一个可以对任何图像进行分割的项目,其论文介绍可以查看https://blog.csdn.net/a486259/article/details/131137939,其测试网站为 https://segment-anything.com

这里对Segment Anything项目的使用进行初步总结,绝大部分内容源自https://github.com/facebookresearch/segment-anything 。

注:segment-anything训练VIT模型时的输入size为1024x1024,其输出的feature size为256x64x64,进行了16倍的下采样

1、安装segment-anything

下载segment-anything项目,进入目录后执行pip install -e .安装项目。

git clone [email protected]:facebookresearch/segment-anything.git
cd segment-anything
pip install -e .

该项目依赖opencv-python pycocotools matplotlib onnxruntime onnx torch等包,安装命令如下

pip install opencv-python pycocotools matplotlib onnxruntime onnx torch

segment-anything模型是基于torch框架实现的

2. 根据提示输入生成mask

Segment Anything Model (SAM) 预测对象mask,给出所需识别出对象的提示输入(对象的粗略位置信息)。该模型首先将图像转换为图像嵌入,然后解码器根据用户输入的提示(粗略位置信息)可以生成高质量的掩模。
SamPredictor类为模型调用提供了一个简单的接口,用于提示模型的输入。它先让用户使用“set_image”方法设置图像,该方法会将图像输入转换到特征空间嵌入。然后,可以通过“predict”方法输入提示信息,以根据这些提示有效地预测掩码。predict函数支持将点和框提示以及上一次预测迭代中的mask作为输入。

2.1 前置函数库

前置库导入和函数实现

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    

2.2 显示样图

读取图片并展示

image = cv2.imread('images/truck.jpg')
image = cv2.resize(image,None,fx=0.5,fy=0.5)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()

2.3 加载SAM模型

sam_vit_b_01ec64模型的下载地址为: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
这里需要注意要使用cuda

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)
#predictor.set_image(image)

其它版本的模型下载地址为

通过调用“SamPredictor.set_image”处理图像以生成图像嵌入(特征向量)。“SamPrejector”会记住此特征向量,并将其用于后续掩码预测。

predictor.set_image(image)

2.4 单点输入生成mask

要选择卡车,可以卡车上选择一个点。点以(x,y)格式输入到模型中,并带有标签1(前景点)或0(背景点)。可以输入多个点;这里我们只使用一个。所选的点将在图像上显示为星形。
此时代码及执行效果如下:

input_point = np.array([[250, 187]])
input_label = np.array([1])

plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()  

使用“SamPredictor.prdict”进行预测。该模型返回掩码(masks)、掩码的分数(scores)以及可传递到下一次预测迭代的低分辨率掩码(logits)。

在“multimask_output=True”(默认设置)的情况下,SAM输出3个掩码,其中“scores”给出了模型对这些掩码质量的估计。此设置用于存在不明确输入提示的时候(光凭一个点无法有效识别出用户意图是组件局部、组件还是整体),并帮助模型消除与提示一致的不同对象的歧义。当为“multimask_output=False”时,它将返回一个掩码。对于单点等不明确的提示,建议使用“multimask_output=True”,即使只需要一个掩码;可以通过选择在“分数”中返回的分数最高的一个来选择最佳的单个掩码。这通常会得到更好的mask。

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

print(masks.shape)  # (number_of_masks) x H x W  | output (3, 600, 900)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {
      
      i+1}, Score: {
      
      score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  



3、多输入生成mask

3.1 多点输入生成mask

单个输入点不明确,需要让模型返回了与其一致的多个对象。要获得单个对象,可以提供多个点。如果可用,还可以将先前迭代的掩码(logits值)提供给模型以帮助预测。当使用多个提示指定单个对象时,可以通过设置“multimask_output=False”来请求获取单个掩码。

input_point = np.array([[250, 184], [562, 322]])
input_label = np.array([1, 1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
print(masks.shape) #output: (1, 600, 900)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show() 

输入负点明确区域

input_label与input_point相对应,为0时表示是负点

input_point = np.array([[250, 187], [561, 322]])
input_label = np.array([1, 0])#为0时表示是负点,即第二个点[561, 322]是负点

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show() 

3.2 boxes输入生成mask

支持将xyxy格式的box作为输入,将框内的主体目标识别出来(类似于实例分割)

input_box = np.array([212, 300, 350, 437])
masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

3.3 同时输入点与boxes生成mask

point和boxes可以同时输入,只需将这两种类型的提示都包括在预测器中即可。在这里,这可以用来只选择卡车的轮胎(将车轴部分设置为负点),而不是整个车轮。

input_box = np.array([215, 310, 350, 430]) #只能默认框住正类
input_point = np.array([[287, 375]]) 
input_label = np.array([0]) #将车轴部分设置为负点

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

3.4 同时输入多个boxes生成mask

SamPredictor可以使用predict_tarch方法对同一图像输入多个提示(points、boxes)。该方法假设输入点已经是tensor张量,且boxes信息与image size相符合。例如,假设我们有几个来自对象检测器的输出结果。
SamPredictor对象(此外也可以使用segment_anything.utils.transforms)可以将boxes信息编码为特征向量(以实现对任意数量boxes的支持,transformed_boxes),然后预测mask。

input_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=predictor.device) #假设这是目标检测的预测结果
input_boxes=input_boxes/2

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)

print(masks.shape)  # (batch_size) x (num_predicted_masks_per_input) x H x W | output: torch.Size([4, 1, 600, 900])

plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()

3.5 端到端的批量推理

如果所有提示输入都已经明确的,则可以以端到端的方式直接运行SAM。这允许SAM对图像进行批处理,以下代码构建了2个image和boxes。

image1 = cv2.imread('images/truck.jpg')
image1_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=sam.device)

image2 = cv2.imread('images/groceries.jpg')
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
image2_boxes = torch.tensor([
    [450, 170, 520, 350],
    [350, 190, 450, 350],
    [500, 170, 580, 350],
    [580, 170, 640, 350],
], device=sam.device)

图像和提示都作为PyTorch张量输入,这些张量(图像和提示输入)已经被编码为特征向量。所有的输入数据都被封装为list,每个元素都是一个dict,它的key如下:

  • image: CHW格式的PyTorch tensor .
  • original_size: 图像原始大小, (H, W) format.
  • point_coords: 一批输入点格式.
  • point_labels: 每个输入点所对应的类型(正例或负例).
  • boxes: 一批输入的boxe(只能是正例).
  • mask_inputs: 一批输入的mask.

如果没有相应的信息,可以不进行输入,但image必须输入

from segment_anything.utils.transforms import ResizeLongestSide
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

def prepare_image(image, transform, device):
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=device.device) 
    return image.permute(2, 0, 1).contiguous()

batched_input = [
     {
    
    
         'image': prepare_image(image1, resize_transform, sam),
         'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
         'original_size': image1.shape[:2]
     },
     {
    
    
         'image': prepare_image(image2, resize_transform, sam),
         'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
         'original_size': image2.shape[:2]
     }
]
batched_output = sam(batched_input, multimask_output=False)
print(batched_output[0].keys()) # output:dict_keys(['masks', 'iou_predictions', 'low_res_logits'])

输出是每个输入图像的结果列表,其中元素是字典对象,其key为:

  • masks: 一批mask,tensor张量
  • iou_predictions: 与mask相对应的iou预测值.
  • low_res_logits: 每个掩码的低分辨率logits,可以在以后的迭代中作为掩码输入再次调用模型。
fig, ax = plt.subplots(1, 2, figsize=(20, 20))

ax[0].imshow(image1)
for mask in batched_output[0]['masks']:
    show_mask(mask.cpu().numpy(), ax[0], random_color=True)
for box in image1_boxes:
    show_box(box.cpu().numpy(), ax[0])
ax[0].axis('off')

ax[1].imshow(image2)
for mask in batched_output[1]['masks']:
    show_mask(mask.cpu().numpy(), ax[1], random_color=True)
for box in image2_boxes:
    show_box(box.cpu().numpy(), ax[1])
ax[1].axis('off')

plt.tight_layout()
plt.show()

4、自动生成mask

4.1 基础前置库

这里加载了一些基础库,并读取images/dog.jpg作为样例数据

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

image = cv2.imread('images/dog.jpg')
image = cv2.resize(image,None,fx=0.5,fy=0.5)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

4.2 自动生成mask

要自动生成mask,请向“SamAutomaticMaskGenerator”类注入SAM模型(需要先初始化SAM模型)

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
#自动生成采样点对图像进行分割
mask_generator = SamAutomaticMaskGenerator(sam)

masks = mask_generator.generate(image)

print(len(masks))
print(masks[0].keys())
print(masks[0])

plt.figure(figsize=(16,16))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 

代码输出的文字信息如下:

42
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])
{
    
    'segmentation': array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False],
       [False, False, False, ..., False, False, False]]), 'area': 18821, 'bbox': [0, 113, 207, 152], 'predicted_iou': 0.9937220215797424, 'point_coords': [[93.75, 146.015625]], 'stability_score': 0.9622295498847961, 'crop_box': [0, 0, 400, 267]}

所生成的图像如下

masks = mask_generator.generate(image)

Mask generation返回该图像所有的masks信息,每一个mask都是一个字典对象,mask的keys如下:

  • segmentation : np的二维数组,为二值的mask图片
  • area : mask的像素面积
  • bbox : mask的外接矩形框,为XYWH格式
  • predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
  • point_coords : 用于生成该mask的point输入
  • stability_score : mask质量的附加指标
  • crop_box : 用于以XYWH格式生成此遮罩的图像裁剪

4.3 自动mask的参数

在自动掩模生成中有几个可调参数,用于控制采样点的密度以及去除低质量或重复掩模的阈值。此外,SamAutomaticMaskGenerator可以自动在图像上切片运行,以提高较小对象的性能,可以通过后处理去除杂散像素和孔洞。以下是对更多遮罩进行采样的示例配置:

mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,#控制采样点的间隔,值越小,采样点越密集
    pred_iou_thresh=0.86,#mask的iou阈值
    stability_score_thresh=0.92,#mask的稳定性阈值
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=50,  #最小mask面积,会使用opencv滤除掉小面积的区域
)
masks2 = mask_generator_2.generate(image)

print(len(masks2)) # 69

plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show() 

猜你喜欢

转载自blog.csdn.net/a486259/article/details/131194434
今日推荐