SAM【1】:Segment Anything


Foreword

Segment Anything is Meta's first open-source large segmentation model, and it has recently set off a wave of large models in the CV field. Within just a few days, a stream of follow-up work and evaluations appeared. Meta also released an online demo of the model, allowing researchers to experience the power of SAM firsthand.

Large vision models resemble large language models in that their main goal is to solve all of a user's problems with one model. Because image data spans a much wider range of image types and tasks, SAM currently focuses on segmentation, the most traditional and widely used task. By bringing the prompt paradigm from NLP into CV, SAM provides broader support and a deeper line of research for foundation models in CV: with a suitable prompt, it can handle new samples zero-shot, and sometimes even tasks that were not considered when the model was designed.

This article mainly analyzes the method behind SAM, which also lays a good foundation for further study of large models. If you only want to understand SAM's architecture and methods, you can jump directly to Section 2.2.

Original paper link: Segment Anything


1. Abstract & Introduction

1.1. Abstract

This paper proposes a new task, model, and dataset for image segmentation. The model is designed and trained to be promptable, so it can transfer zero-shot to new image distributions and tasks. Experiments evaluate it on numerous tasks and find its zero-shot performance impressive, often competitive with or even better than prior fully supervised results.

1.2. Introduction

Large language models pre-trained on web-scale datasets have powerful zero-shot and few-shot generalization. These foundation models can generalize to tasks and data distributions beyond those seen during training, a capability typically achieved through prompt engineering.

Foundation models have also been explored in vision. For example, CLIP and ALIGN use contrastive learning to train text and image encoders that align the two modalities. Once trained, engineered text prompts let them generalize zero-shot, and the encoders can be combined with other modules for downstream tasks such as image generation.

The goal of this work is to build a promptable model and pre-train it on a broad dataset with a task that enables strong generalization, so that prompting can then solve a range of downstream segmentation problems on new data distributions.

To achieve this goal, the paper identifies three questions that need to be answered:

  • What task will enable zero-shot generalization?
  • What is the corresponding network structure?
  • What kind of datasets can drive such tasks and models?


In summary, this paper proposes the following solutions and explores some other related issues:

  • Task
    • Create a promptable segmentation task: for any form of segmentation prompt, such as points, a box, a mask, or text (not yet released in the open-source code), return a valid segmentation mask
    • Even when the prompt is ambiguous, the model should still output a reasonable segmentation
    • Prompt engineering
      • Prompt engineering means designing prompts that help solve a specific downstream segmentation task
      • Using the knowledge gained from pre-training on the promptable segmentation task, prompts can be designed to guide the model to produce valid masks for specific objects or regions in an image
  • Model
    • Starting from the task requirements, the model needs to meet the following three points:
      • Support flexible prompts
      • Compute masks in real time to allow interactive use
      • Be ambiguity-aware
    • The paper proposes an architecture that meets all three requirements: a powerful image encoder computes an image embedding, a prompt encoder embeds the prompts, and a lightweight mask decoder combines the two to predict segmentation masks. Because the image embedding can be reused across prompts, the prompt encoder and mask decoder can run in real time
  • Data Engine & Dataset
    • A model with strong generalization needs a large and diverse dataset. Because segmentation masks are scarce, the paper builds a Data Engine in three stages:
      • Assisted-manual (SAM helps annotators label masks, similar to interactive segmentation)
      • Semi-automatic (SAM automatically generates masks for some objects given prompts; annotators focus on the rest)
      • Fully automatic (SAM is prompted with a regular grid of points and generates masks without human input)
    • The newly built SA-1B dataset contains about 11 million images and over 1 billion masks, roughly 400 times more masks than any existing segmentation dataset (publicly released)

2. Segment Anything Model

2.1. Segment Anything Task

2.1.1. Task

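A prompt can be a set of foreground/background points, a rough box or mask, or free-form text; in general, any information indicating what to segment in an image. Given a prompt, the model returns a valid segmentation mask. Valid means that even when the prompt is ambiguous (for example, a point on a shirt could refer to the shirt or to the person wearing it), the output should be a reasonable mask for at least one of the plausible objects; the model can output multiple candidate masks for the user to choose from.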

This task leads to a natural pre-training algorithm and a general method for zero-shot transfer to downstream segmentation tasks via prompting.


2.1.2. Pre-training

The pre-training task draws inspiration from interactive segmentation.

Interactive Segmentation: a classic computer vision task in which an algorithm segments an image into regions or objects based on user input. The algorithm takes hints from the user (such as clicks or boxes) to refine its result; in other words, the user interacts with the algorithm and guides it toward a more accurate segmentation.

SAM is pre-trained by simulating a sequence of prompts (points, boxes, masks, or text) for each training sample and comparing the predicted masks with the ground truth. Unlike interactive segmentation, whose aim is to eventually predict a valid mask after enough user input, this task requires a valid mask for any prompt, even an ambiguous one, which calls for careful choices of model and training loss.

2.1.3. Zero-shot transfer

Pre-training gives the model the ability to respond appropriately to any prompt at inference time, so downstream tasks can be solved by engineering suitable prompts.

Because SAM can respond to any prompt, a downstream task can be reduced to designing the right prompts. For example, given a bounding-box detector for cats, cat instance segmentation can be solved by feeding the detected boxes to SAM as prompts.

2.2. Segment Anything Model Methods


2.2.1. Image Encoder

The image encoder is an MAE pre-trained Vision Transformer (ViT), minimally adapted to process high-resolution inputs. For a detailed explanation of ViT, see my other blog post: CV-Model [6]: Vision Transformer

The image encoder runs once per image and can be applied before the model is prompted. Depending on the size of the image encoder, three pre-trained checkpoints are available, from large to small: ViT-H (vit_h), ViT-L (vit_l), and ViT-B (vit_b).

SAM uses a standard ViT as its image encoder. The original image is resized proportionally so that its longer side is $1024$, and the shorter side is padded, giving a $1024 \times 1024$ input. A convolution with kernel size $16$ and stride $16$ then splits the image into patches, producing a $64 \times 64 \times 768$ $(W, H, C)$ image embedding ($768$ is the embedding width of ViT-B; ViT-L and ViT-H use $1024$ and $1280$). The $W$ and $H$ dimensions are treated as a sequence of $64 \times 64$ tokens and fed into a stack of Transformer encoder layers. To reduce the channel dimension, the ViT output then passes through two convolution layers (kernel sizes $1$ and $3$, each followed by LayerNorm2d), compressing the feature dimension to $256$.
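
To make the arithmetic concrete, here is a NumPy sketch of the preprocessing and the patch split. It assumes bottom/right zero padding and uses hypothetical helper names; a random matrix stands in for the learned 16 × 16, stride-16 convolution, which is equivalent to cutting the image into non-overlapping patches and applying one shared linear layer to each. Note that a patch holds 16 × 16 × 3 = 768 raw values.

```python
import numpy as np

def resize_shape(h, w, long_side=1024):
    """Scale (h, w) so the longer side equals long_side, keeping the aspect ratio."""
    scale = long_side / max(h, w)
    return int(h * scale + 0.5), int(w * scale + 0.5)

def pad_to_square(img, size=1024):
    """Zero-pad an (H, W, C) image on the bottom/right to (size, size, C)."""
    h, w, _ = img.shape
    return np.pad(img, ((0, size - h), (0, size - w), (0, 0)))

def patchify(img, patch=16):
    """Split (H, W, C) into non-overlapping patches -> (H/p, W/p, p*p*C)."""
    h, w, c = img.shape
    x = img.reshape(h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)                 # (grid_h, grid_w, p, p, C)
    return x.reshape(h // patch, w // patch, patch * patch * c)

rng = np.random.default_rng(0)
h, w = resize_shape(480, 640)                      # longer side 640 -> 1024
img = rng.random((h, w, 3), dtype=np.float32)      # stand-in for the resized image
x = pad_to_square(img)                             # (1024, 1024, 3)
patches = patchify(x)                              # (64, 64, 768)
proj = rng.standard_normal((patches.shape[-1], 768)).astype(np.float32)
tokens = patches @ proj                            # (64, 64, 768) patch embeddings
print(tokens.shape)                                # (64, 64, 768)
```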

The implementation code is as follows:

self.neck = nn.Sequential(
    nn.Conv2d(
        embed_dim,
        out_chans,
        kernel_size=1,
        bias=False,
    ),
    LayerNorm2d(out_chans),
    nn.Conv2d(
        out_chans,
        out_chans,
        kernel_size=3,
        padding=1,
        bias=False,
    ),
    LayerNorm2d(out_chans),
)
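
The snippet depends on `LayerNorm2d` from the official repo, which (as I read it) normalizes each spatial position across the channel axis. Here is a self-contained NumPy sketch of the whole neck with random stand-in weights; the helper names are my own, and the spatial size is reduced from 64 to 8 only to keep it fast.

```python
import numpy as np

def conv1x1(x, w):
    """1x1 convolution without bias: x (C_in, H, W), w (C_out, C_in) -> (C_out, H, W)."""
    return np.einsum("oc,chw->ohw", w, x)

def conv3x3_same(x, w):
    """3x3 convolution, padding=1, no bias (cross-correlation, like nn.Conv2d)."""
    _, h, wd = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.zeros((w.shape[0], h, wd), dtype=x.dtype)
    for i in range(3):
        for j in range(3):
            out += np.einsum("oc,chw->ohw", w[:, :, i, j], xp[:, i:i + h, j:j + wd])
    return out

def layer_norm_2d(x, weight, bias, eps=1e-6):
    """Normalize every (h, w) position over the channel axis, then scale and shift."""
    mean = x.mean(axis=0, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=0, keepdims=True)
    x = (x - mean) / np.sqrt(var + eps)
    return weight[:, None, None] * x + bias[:, None, None]

rng = np.random.default_rng(0)
embed_dim, out_chans, hw = 768, 256, 8             # the real model uses hw = 64
feat = rng.standard_normal((embed_dim, hw, hw))    # stand-in ViT output (C, H, W)
w1 = rng.standard_normal((out_chans, embed_dim)) * 0.02
w2 = rng.standard_normal((out_chans, out_chans, 3, 3)) * 0.02
ones, zeros = np.ones(out_chans), np.zeros(out_chans)

y = layer_norm_2d(conv1x1(feat, w1), ones, zeros)    # 768 -> 256 channels
y = layer_norm_2d(conv3x3_same(y, w2), ones, zeros)  # 3x3 mixing, same spatial size
print(y.shape)                                       # (256, 8, 8)
```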

MAE (masked autoencoder) is a scalable self-supervised learning method for computer vision: it masks random patches of the input image and reconstructs the missing pixels. Even with 95% of the patches masked, it can still recover the rough outline of objects.


In MAE, the image is split into non-overlapping patches as in ViT. Only the visible (unmasked) patches, a small subset, are fed into the ViT encoder to learn patch representations. The encoded visible patches and the mask tokens (shown in gray; all masked positions share one learned embedding but receive different positional embeddings) are then arranged in the original patch order and passed to a lightweight Transformer decoder that reconstructs the image. The loss is the mean squared error (L2 loss) between the reconstructed and original pixels, computed only on the masked patches. After pre-training, only the encoder is kept for extracting image features.
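
The masking and loss described above can be sketched in NumPy as follows. The 75% masking ratio is MAE's default; the function names, the patch count, and the all-zero "prediction" standing in for the decoder output are assumptions for illustration.

```python
import numpy as np

def random_masking(num_patches, mask_ratio, rng):
    """Return (keep_idx, mask); mask[i] is True for patches hidden from the encoder."""
    num_keep = int(num_patches * (1 - mask_ratio))
    perm = rng.permutation(num_patches)
    keep_idx = np.sort(perm[:num_keep])
    mask = np.ones(num_patches, dtype=bool)
    mask[keep_idx] = False
    return keep_idx, mask

def masked_mse(pred, target, mask):
    """Per-patch MSE averaged over masked patches only (visible patches add no loss)."""
    per_patch = ((pred - target) ** 2).mean(axis=-1)
    return per_patch[mask].mean()

rng = np.random.default_rng(0)
patches = rng.standard_normal((196, 768))       # e.g. 14 x 14 patches of a 224 x 224 image
keep_idx, mask = random_masking(len(patches), mask_ratio=0.75, rng=rng)
visible = patches[keep_idx]                     # only these go through the encoder
pred = np.zeros_like(patches)                   # stand-in for the decoder's reconstruction
loss = masked_mse(pred, patches, mask)
print(visible.shape, mask.sum())                # (49, 768) 147
```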

Since the image embedding does not depend on the prompt, it can be computed once and reused: different prompts can be applied to the same embedding many times, which is very useful in interactive segmentation.

2.2.2. Prompt Encoder

Based on the segmentation task requirements, the prompts supported by SAM fall into two categories:

Sparse prompts

Includes points, boxes, and free-form text

A point is represented as the sum of the point's positional encoding and one of two learned embeddings that indicate whether the point is in the foreground or background

A positional encoding maps each location in the image to a vector that encodes its coordinates, so the network knows where a prompt is. In the official code this part is not learned: it is a fixed random Fourier-feature projection of the normalized (x, y) coordinates. What is learned are the type embeddings added on top of it, which tell the network what kind of prompt the position belongs to (foreground point, background point, box corner).

A single point is often ambiguous because the object under it may have several nested parts, so the point belongs to several candidate masks. In this case the model returns three masks by default (whole, part, and sub-part). If more than one point is given, the points are read in turn, and of the three masks predicted in the previous step, the one with the highest predicted score is used as the mask prompt for the next prediction.
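
The selection step described above can be sketched in a few lines of NumPy. This is an illustrative sketch rather than the official code: `pick_best_mask` is a hypothetical helper, and the random logits and scores stand in for real decoder outputs.

```python
import numpy as np

def pick_best_mask(mask_logits, iou_scores):
    """Return (index, logits) of the candidate mask with the highest predicted IoU.

    mask_logits: (3, H, W) unthresholded logits for the three candidates.
    iou_scores:  (3,) predicted quality score for each candidate.
    """
    best = int(np.argmax(iou_scores))
    return best, mask_logits[best]

rng = np.random.default_rng(0)
logits = rng.standard_normal((3, 256, 256))   # stand-in decoder output
scores = np.array([0.71, 0.93, 0.85])          # illustrative predicted IoUs
best, mask_prompt = pick_best_mask(logits, scores)
binary_mask = mask_prompt > 0                   # logit > 0 means probability > 0.5
print(best, mask_prompt.shape)                  # 1 (256, 256)
```

The unthresholded logits, not the binary mask, are what would be fed back as the dense prompt in the next round.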

The specific encoding method of the point is as follows:

point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
# label -1 marks a padding point: drop its positional encoding
point_embedding[labels == -1] = 0.0
# self.not_a_point_embed is a learned embedding for padding points
point_embedding[labels == -1] += self.not_a_point_embed.weight
# self.point_embeddings are learned type embeddings: 0 = background, 1 = foreground
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
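
Put together, a point prompt is "fixed positional encoding + learned type embedding". The sketch below re-implements this in NumPy, assuming (from my reading of the official `PositionEmbeddingRandom`) that the positional part is a random Fourier-feature projection of the normalized coordinates with a fixed Gaussian matrix. The helper names are hypothetical and all embeddings are random stand-ins.

```python
import numpy as np

def random_fourier_pe(coords, gaussian, image_size):
    """coords: (N, 2) pixel (x, y); gaussian: (2, F) fixed random matrix -> (N, 2F)."""
    h, w = image_size
    xy = coords / np.array([w, h], dtype=np.float64)   # normalize to [0, 1]
    xy = 2 * xy - 1                                     # map to [-1, 1]
    proj = 2 * np.pi * (xy @ gaussian)
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)

def embed_points(coords, labels, gaussian, type_emb, not_a_point, image_size):
    """labels: 1 = foreground, 0 = background, -1 = padding (no positional term)."""
    emb = random_fourier_pe(coords + 0.5, gaussian, image_size)  # shift to pixel centre
    emb[labels == -1] = 0.0
    emb[labels == -1] += not_a_point
    emb[labels == 0] += type_emb[0]
    emb[labels == 1] += type_emb[1]
    return emb

rng = np.random.default_rng(0)
dim = 256
gaussian = rng.standard_normal((2, dim // 2))     # fixed, never trained
type_emb = rng.standard_normal((2, dim)) * 0.02   # learned in the real model
not_a_point = rng.standard_normal(dim) * 0.02     # learned in the real model
coords = np.array([[500.0, 375.0], [120.0, 80.0], [0.0, 0.0]])
labels = np.array([1, 0, -1])
emb = embed_points(coords, labels, gaussian, type_emb, not_a_point, (1024, 1024))
print(emb.shape)                                  # (3, 256)
```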

A box is represented by a pair of embeddings, one for its upper-left corner and one for its lower-right corner:

  • The positional encoding of the upper left corner is summed with the learned embedding representing the upper left corner
  • The positional encoding of the bottom right corner is summed with the learned embedding representing the bottom right corner

The box encoding is implemented as follows:

def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    """Embeds box prompts."""
    boxes = boxes + 0.5  # Shift to center of pixel
    coords = boxes.reshape(-1, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    return corner_embedding

To represent free-form text, the paper uses the text encoder from CLIP (in general, any text encoder would work).

The released open-source code does not include text prompts. The paper describes the approach as follows:


  1. A pre-trained CLIP model (ViT-L/14@336px) supplies both the text encoder and an image encoder used to build prompts; SAM's own MAE-pretrained image encoder is kept for the full image. The CLIP embeddings are 768-dimensional while the point and box embeddings are 256-dimensional, so a fully connected layer is still needed to align the dimensions. The text and image embeddings are ℓ2-normalized before use
  2. Construct training prompts so that, in the mask decoder, CLIP image embeddings play the role that text embeddings will play at inference
    1. Take the masks from the first two stages of the Data Engine (whose annotations are accurate, as discussed later) and discard masks with an area smaller than 100² pixels. No text descriptions are needed, because CLIP already aligns image and text embeddings
    2. Expand the bounding box of each mask by a random factor between 1× and 2×, square-crop it, and resize the crop to 336 × 336 pixels as the CLIP image input
    3. To help the embedding focus on the object, with 50% probability the pixels outside the mask are set to 0; the last CLIP layer uses masked attention so that the output token only attends to positions inside the mask
    4. The CLIP image embedding of the crop from steps 2 and 3 is used as the prompt
  3. At inference, the text is encoded by CLIP's unmodified text encoder, and the resulting embedding is given to SAM as the prompt
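
The prompt-generation steps above boil down to some box arithmetic plus optional masking. A rough NumPy sketch under my own reading of the paper: "square-crop" is taken to mean a square whose side is the longer side of the expanded box, crops that extend past the image border are not handled, and all function names are hypothetical. The resulting 336 × 336 crop would then go through the CLIP image encoder, which is not reproduced here.

```python
import numpy as np

def keep_for_text_training(mask_area, min_side=100):
    """Masks with an area smaller than 100^2 pixels are discarded."""
    return mask_area >= min_side ** 2

def square_crop_box(mask_box, factor):
    """Expand an (x0, y0, x1, y1) box by `factor` about its centre and make it square."""
    x0, y0, x1, y1 = mask_box
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    side = max(x1 - x0, y1 - y0) * factor
    return (cx - side / 2, cy - side / 2, cx + side / 2, cy + side / 2)

def maybe_zero_background(crop, mask, rng, p=0.5):
    """With probability p, set pixels outside the mask to 0 (crop: H x W x 3, mask: H x W)."""
    if rng.random() < p:
        return crop * mask[..., None]
    return crop

rng = np.random.default_rng(0)
factor = rng.uniform(1.0, 2.0)                   # random expansion factor in [1, 2)
box = square_crop_box((100, 200, 300, 300), factor)
resize_scale = 336 / (box[2] - box[0])           # the crop is resized to 336 x 336
```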

Dense prompts

Includes masks

Dense prompts (i.e., masks) have a spatial correspondence with the image.

The mask is input at $4\times$ lower resolution than the input image (i.e. $256 \times 256$) and is then downscaled by another $4\times$ using two $2 \times 2$ convolutions with stride $2$, whose output channels are $4$ and $16$ respectively. A final $1 \times 1$ convolution maps the channel dimension to $256$. Each layer is separated by layer normalization and a GELU activation. The resulting mask embedding is then added element-wise to the image embedding.

If no mask prompt is provided, a learned "no mask" embedding is added at every image-embedding location instead.

The mask downscaling network is implemented as follows:

self.mask_downscaling = nn.Sequential(
    nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans // 4),
    activation(),
    nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans),
    activation(),
    nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
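
As a sanity check on the resolutions above, the shape arithmetic can be traced with a small pure-Python helper. It only tracks shapes, the helper names are hypothetical, and it assumes `mask_in_chans = 16` and `embed_dim = 256` as the channel counts in the text suggest.

```python
def conv_out(size, kernel, stride, padding=0):
    """Spatial output size of a convolution (floor division, as in nn.Conv2d)."""
    return (size + 2 * padding - kernel) // stride + 1

def mask_downscaling_shapes(image_size=1024, mask_in_chans=16, embed_dim=256):
    """List of (C, H, W) shapes through the mask downscaling network."""
    size = image_size // 4                              # mask prompt is 4x lower-res
    shapes = [(1, size, size)]
    for out_c in (mask_in_chans // 4, mask_in_chans):   # two 2x2, stride-2 convs
        size = conv_out(size, kernel=2, stride=2)
        shapes.append((out_c, size, size))
    shapes.append((embed_dim, size, size))              # final 1x1 conv keeps the size
    return shapes

print(mask_downscaling_shapes())
# [(1, 256, 256), (4, 128, 128), (16, 64, 64), (256, 64, 64)]
```

The last shape matches the $64 \times 64 \times 256$ image embedding, which is what makes the element-wise addition possible.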

2.2.3. Mask Decoder

The core of the mask decoder is a Transformer that updates the prompt-aligned image embedding together with a set of extra output tokens: one IoU token and four mask tokens (three for the multi-mask outputs and one used when a single output mask is requested). The token embeddings produced by the Transformer are then fed to the final prediction heads to obtain the results.

The transformer has 3 inputs:

  • token embedding
    • concatenation of the output token embeddings and the prompt token (sparse prompt) embeddings
# 1 iou_token; 4 mask_tokens: 3 for the multi-mask outputs, 1 for the single-mask output
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
# B x (num_points + 2*num_boxes + 5) x 256
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
  • src
    • sum of image embedding and dense prompt embedding
# Expand per-image data in batch direction to be per-mask
# every prompt in the batch gets its own copy of the image embedding
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
  • pos_src
    • The positional encoding of the image. Unlike the standard ViT, which encodes the 1D patch index, it is a 2D encoding computed from each position's (x, y) coordinates (similar in spirit to DETR's 2D encoding; in the official code it is the same random Fourier-feature encoding used for the prompts)
    • A 1D patch index does not directly expose the row (y) coordinate, whereas the 2D encoding keeps both axes explicit
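
To make the shapes concrete, here is a NumPy sketch of how the decoder inputs are assembled, mirroring the snippets above with random stand-in embeddings (the function name and sizes are assumptions). Two prompts, each with 3 points and 1 box, give N = 3 + 2 × 1 = 5 sparse embeddings per prompt.

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W = 256, 64, 64
iou_token = rng.standard_normal((1, C))      # learned in the real model
mask_tokens = rng.standard_normal((4, C))    # 3 multi-mask outputs + 1 single-mask output

def build_decoder_inputs(sparse, image_emb, dense, image_pe):
    """sparse: (B, N, C) prompt embeddings; image_emb, image_pe: (1, C, H, W) for one image;
    dense: (B, C, H, W) mask-prompt (or "no mask") embeddings."""
    b = sparse.shape[0]
    output_tokens = np.concatenate([iou_token, mask_tokens], axis=0)     # (5, C)
    output_tokens = np.broadcast_to(output_tokens, (b, *output_tokens.shape))
    tokens = np.concatenate([output_tokens, sparse], axis=1)            # (B, 5 + N, C)
    src = np.repeat(image_emb, b, axis=0) + dense                        # one image copy per prompt
    pos_src = np.repeat(image_pe, b, axis=0)
    return tokens, src, pos_src

sparse = rng.standard_normal((2, 5, C))
image_emb = rng.standard_normal((1, C, H, W))
dense = rng.standard_normal((2, C, H, W))
image_pe = rng.standard_normal((1, C, H, W))
tokens, src, pos_src = build_decoder_inputs(sparse, image_emb, dense, image_pe)
print(tokens.shape, src.shape)   # (2, 10, 256) (2, 256, 64, 64)
```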


The specific implementation process is:

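  • Learnable output tokens are concatenated with the prompt embeddings to carry the decoder output; each decoder layer then performs:
    1. Self-attention over the tokens (prompt tokens + output tokens)
    2. Cross-attention from the tokens to the image embedding (tokens as Q)
    3. A point-wise MLP updates each token
    4. Cross-attention from the image embedding to the tokens from step 3 (image embedding as Q)
  • The above steps are repeated for 2 layers; every attention and MLP block has a residual connection followed by layer norm. A final token-to-image attention layer is then applied, and the outputs are used to predict the masks and IoU scores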
def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
    # Self attention block
    if self.skip_first_layer_pe:
        queries = self.self_attn(q=queries, k=queries, v=queries)
    else:
        q = queries + query_pe
        attn_out = self.self_attn(q=q, k=q, v=queries)
        queries = queries + attn_out
    queries = self.norm1(queries)
    # Cross attention block, tokens attending to image embedding
    # queries: token embeddings, updated layer by layer; query_pe: the original (initial) token embeddings
    q = queries + query_pe
    # keys: src (image embedding); key_pe: image positional encoding
    k = keys + key_pe
    attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
    queries = queries + attn_out
    queries = self.norm2(queries)
    # MLP block
    mlp_out = self.mlp(queries)
    queries = queries + mlp_out
    queries = self.norm3(queries)
    # Cross attention block, image embedding attending to tokens
    q = queries + query_pe
    k = keys + key_pe
    attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
    keys = keys + attn_out
    keys = self.norm4(keys)
    return queries, keys
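
The two cross-attention directions in this block can be illustrated with a stripped-down NumPy sketch: single head, no learned q/k/v projections, no LayerNorm or MLP (the real layers have all of these). It only shows the pattern used above: positional encodings are added to queries and keys but not to values, the tokens first attend to the image, and then the image attends back to the updated tokens. Sizes are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)    # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    """Single-head scaled dot-product attention: q (Nq, C), k and v (Nk, C) -> (Nq, C)."""
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

rng = np.random.default_rng(0)
c = 32
tokens = rng.standard_normal((7, c))       # e.g. 5 output tokens + 2 prompt tokens
token_pe = tokens.copy()                   # query_pe: the initial token embeddings
image = rng.standard_normal((64, c))       # an 8 x 8 image embedding, flattened
image_pe = rng.standard_normal((64, c))    # image positional encoding

# tokens -> image: positional encodings go into queries and keys, not values
tokens = tokens + attention(tokens + token_pe, image + image_pe, image)
# image -> tokens: each image position queries the updated tokens
image = image + attention(image + image_pe, tokens + token_pe, tokens)
print(tokens.shape, image.shape)   # (7, 32) (64, 32)
```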

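The Transformer returns 3 mask-token embeddings. Each passes through its own 3-layer MLP, and the result is combined with the upscaled image embedding (upsampled $4\times$ by two transposed convolutions) through a dot product at every spatial location, giving the 3 final segmentation masks. The IoU token passes through an MLP head that outputs 3 confidence scores, one per mask.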

upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
    hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
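
A NumPy sketch of this hypernetwork step, with random weights and the sizes I believe the official model uses (256-dim tokens, a 32-channel upscaled embedding at 256 × 256, i.e. 4× the 64 × 64 grid); treat the sizes and the helper `mlp3` as assumptions. Each mask is simply a per-pixel dot product between one token-derived vector and the upscaled image features. Following the official slicing, token 0 gives the single-mask output and tokens 1 to 3 give the three multi-mask outputs.

```python
import numpy as np

def mlp3(x, weights):
    """3-layer MLP with ReLU between layers and no activation on the output."""
    for i, w in enumerate(weights):
        x = x @ w
        if i < len(weights) - 1:
            x = np.maximum(x, 0.0)
    return x

rng = np.random.default_rng(0)
b, n_tok, c, up_c, h, w = 1, 4, 256, 32, 256, 256
mask_tokens_out = rng.standard_normal((b, n_tok, c))     # decoder outputs for the 4 mask tokens
upscaled = rng.standard_normal((b, up_c, h, w))          # image embedding after 4x upsampling
mlps = [[rng.standard_normal(s) * 0.05 for s in [(c, c), (c, c), (c, up_c)]]
        for _ in range(n_tok)]                           # one MLP per mask token

hyper_in = np.stack([mlp3(mask_tokens_out[:, i], mlps[i]) for i in range(n_tok)], axis=1)
masks = (hyper_in @ upscaled.reshape(b, up_c, h * w)).reshape(b, -1, h, w)
multimask = masks[:, 1:]   # tokens 1..3: the three ambiguity-aware masks
single = masks[:, :1]      # token 0: single-mask output
print(masks.shape)         # (1, 4, 256, 256)
```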

2.2.4. Losses and training

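Mask prediction is supervised with a linear combination of focal loss and dice loss (a 20:1 ratio of focal to dice in the paper), which helps with the strong foreground/background imbalance in masks. The IoU prediction head is trained with a mean-squared-error loss between the predicted IoU and the actual IoU of the predicted mask with the ground truth.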
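
A NumPy sketch of the combined mask loss. The 20:1 weighting follows the paper; the focal-loss hyperparameters (α = 0.25, γ = 2) and the dice smoothing term are common defaults that I am assuming, and the function names are my own.

```python
import numpy as np

def focal_loss(logits, target, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss, averaged over pixels."""
    p = 1.0 / (1.0 + np.exp(-logits))
    ce = np.logaddexp(0.0, logits) - target * logits        # stable BCE with logits
    p_t = p * target + (1 - p) * (1 - target)                # probability of the true class
    alpha_t = alpha * target + (1 - alpha) * (1 - target)
    return float((alpha_t * (1 - p_t) ** gamma * ce).mean())

def dice_loss(logits, target, smooth=1.0):
    """1 - soft Dice coefficient between predicted probabilities and the target mask."""
    p = 1.0 / (1.0 + np.exp(-logits))
    inter = (p * target).sum()
    return float(1 - (2 * inter + smooth) / (p.sum() + target.sum() + smooth))

def mask_loss(logits, target, focal_weight=20.0, dice_weight=1.0):
    return focal_weight * focal_loss(logits, target) + dice_weight * dice_loss(logits, target)

target = np.zeros((16, 16))
target[4:12, 4:12] = 1.0
good = np.where(target > 0, 5.0, -5.0)    # confident, correct logits
bad = -good                                # confident, inverted logits
print(mask_loss(good, target) < mask_loss(bad, target))   # True
```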

2.3. Segment Anything Data Engine

As with LLMs, the design idea here is to scale up model capacity, and under that premise massive training data is crucial to the model's performance. However, unlike natural language or some other image tasks, segmentation masks cannot be obtained from raw images through self-supervision, and manual mask annotation is extremely costly. The paper therefore designs a 3-stage data engine to generate training data.

2.3.1. Assisted-manual stage

Annotators label masks with an interactive annotation tool powered by SAM; no class labels are attached to the masks.

In this stage, SAM is first trained on common public segmentation datasets. Annotators click foreground and background points as prompts and refine SAM's initially imprecise masks. As more data is annotated, SAM is retrained using only the newly annotated masks. In total, the model was retrained 6 times in this stage.

2.3.2. Semi-automatic stage

First, confident masks are detected automatically; annotators then label the remaining unannotated objects, which increases mask diversity.

Boxes from an object detector are used as SAM's prompts (detection is much easier than segmentation). Among the outputs, annotators only need to correct masks with low confidence scores and add the objects SAM missed. In this stage SAM is again retrained periodically as annotated data grows, 5 times in total.

2.3.3. Fully automatic stage

The third stage resembles pseudo-label training: the SAM trained on the previously collected data generates masks on a massive set of images, and rules then filter out likely wrong results. The process is as follows:

  • Prompt the model with a regular $32 \times 32$ grid of points, and for each point predict a set of masks that may correspond to valid objects
    • If a point falls on a sub-part or part, the model returns the sub-part, the part, and the whole object
  • Keep confident masks by filtering on the predicted IoU score
    • Keep stable masks: a mask is stable if thresholding its probability map at $0.5 - \delta$ and $0.5 + \delta$ gives similar masks
  • Use NMS to remove duplicates among the confident and stable masks
    • The same object may be detected multiple times, for example from different prompt points or because of variations in its appearance. NMS removes these redundant detections and keeps only the best one
    • The main idea of NMS is to pick the best boxes based on confidence and overlap. The algorithm works as follows:
      1. For each category, sort all boxes by confidence in descending order (SAM's masks are class-agnostic, so there is a single category)
      2. Select the box with the highest confidence and add it to the final result set
      3. Compute the overlap (usually IoU) between each remaining box and the selected box, and delete the boxes whose overlap exceeds a threshold
      4. Repeat until all boxes are processed
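
The three filtering steps can be sketched together in NumPy. This is my own illustrative code, not the official implementation: the function names are hypothetical, `delta = 0.05` and the NMS threshold of 0.7 are assumed values, and the stability score is computed on a probability map (as I understand it, the official code does the equivalent on mask logits). NMS runs on the masks' bounding boxes without class labels.

```python
import numpy as np

def build_point_grid(n_per_side=32):
    """n x n points at cell centres of the unit square, as (x, y) in [0, 1]."""
    offset = 1 / (2 * n_per_side)
    side = np.linspace(offset, 1 - offset, n_per_side)
    xs, ys = np.meshgrid(side, side)          # xs varies along each row
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)

def stability_score(prob, delta=0.05):
    """IoU between masks thresholded at 0.5 + delta and 0.5 - delta.

    The high-threshold mask is a subset of the low-threshold one, so the
    IoU reduces to |high| / |low|. Values near 1 mean a stable mask."""
    high = (prob > 0.5 + delta).sum()
    low = (prob > 0.5 - delta).sum()
    return high / low if low > 0 else 0.0

def box_iou(box, boxes):
    """IoU of one (x0, y0, x1, y1) box against an (N, 4) array of boxes."""
    x0 = np.maximum(box[0], boxes[:, 0])
    y0 = np.maximum(box[1], boxes[:, 1])
    x1 = np.minimum(box[2], boxes[:, 2])
    y1 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x1 - x0, 0, None) * np.clip(y1 - y0, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

def nms(boxes, scores, iou_thresh=0.7):
    """Greedy class-agnostic NMS; returns kept indices, best score first."""
    order = np.argsort(-scores)
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        order = rest[box_iou(boxes[i], boxes[rest]) <= iou_thresh]
    return keep

points = build_point_grid(32) * np.array([1024, 768])   # pixel (x, y) for a 1024 x 768 image
boxes = np.array([[0, 0, 10, 10], [0, 0, 10, 9], [50, 50, 60, 60]], dtype=float)
scores = np.array([0.9, 0.8, 0.7])
print(points.shape, nms(boxes, scores))   # (1024, 2) [0, 2]
```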

3. Demo

  • Positive point
  • Negative point
  • Box
  • Everything

Summary

The main contribution of this paper is a very large-scale, high-quality segmentation dataset together with a promptable model with strong generalization. The model has the following characteristics:

  • It can serve as a foundation model for computer vision and be applied to downstream tasks
  • Its prompt interface lets SAM be easily composed with other components into larger systems
  • SAM generalizes well, is versatile, and processes prompts in real time


Origin blog.csdn.net/HoraceYan/article/details/130420571