A major shift in image segmentation: from SAM (Segment Anything) to FastSAM and MobileSAM

Foreword

SAM is a general-purpose model for image segmentation. Unlike earlier segmentation models that can only handle particular kinds of images, SAM can handle images of any type.

Before SAM appeared, virtually all image segmentation models were specialized, domain-specific models. In the medical field, for example, there are AI models dedicated to segmenting MRI images and others dedicated to segmenting CT images. Such models perform well only on images from their own domain and often perform poorly on images from other domains.

Following the first two articles in this series, this article covers the three image segmentation models highlighted below (SAM, FastSAM, and MobileSAM)

(Timeline figure: representative models from 2020 to 2023)

  • 2020: DETR, DDPM, Vision Transformer
  • 2021: CLIP, DALL·E, Swin Transformer, MAE, Swin Transformer V2
  • 2022: BLIP, DALL·E 2, Stable Diffusion, BEiT-3, Midjourney V3
  • 2023: BLIP-2, Visual ChatGPT, GPT-4, Midjourney V5, SAM (Segment Anything Model), FastSAM (the Chinese Academy of Sciences' version of SAM), MobileSAM

Part 1: SAM (Segment Anything Model)

1.1 SAM (Segment Anything): building a general segmentation model that segments flexibly according to prompts

  • Large language models pre-trained on web-scale datasets show strong zero-shot and few-shot generalization. These "foundation models" can extend to tasks and data distributions beyond those seen during training, an ability realized through prompt engineering: feed the model a prompt and it produces useful text output. Trained at scale on large web text corpora, such zero-shot and few-shot prompting can even outperform fine-tuned models, and the effect becomes more pronounced as the dataset grows, GPT-3 being the prime example
  • Foundation models have also been explored for vision tasks. For example, CLIP and ALIGN use contrastive learning to align text and image encoders; once trained, the image encoder can be prompted and extended to downstream tasks such as image generation

The purpose of SAM (paper address, code address) is to establish a foundation model for image segmentation, i.e. to develop a model with prompting ability

SAM has 3 questions to solve:

  1. What task enables zero-shot generalization?
    Given a prompt as input, generate a valid mask. When the prompt is ambiguous, multiple objects may be returned (a point on a shirt, for example, can refer to either the shirt or the person wearing it), as shown in the figure below. The prompt can be a point, a box, text, a mask, or an image (a usage sketch is given after this list)

  2. What should the model structure look like?
    The model must support flexible prompts, generate masks in real time, and allow ambiguous outputs (for example, the shirt versus the person wearing it). The design: a prompt encoder encodes the prompts, an image encoder encodes the image into an embedding, the two are combined, and a lightweight mask decoder outputs the final mask

  3. How can data support this task?
    A large and diverse set of masks is required. Natural language data can be obtained online, but mask data is insufficient, so an alternative strategy is needed.
    The solution is to build a "data engine" with 3 stages:
    \rightarrow  assisted-manual (annotators label with the model's help, similar to interactive segmentation)
    \rightarrow  semi-automatic (object masks are generated automatically from prompts)
    \rightarrow  fully automatic (masks are generated automatically using a regular grid of points as prompts)
    As shown in the figure below: labeled data first trains the model, then the model helps label more data, establishing a data loop. In the end 1.1 billion masks were generated from 11 million images, currently the largest segmentation dataset, with 400 times more masks than any existing dataset
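
To make the prompt-to-mask workflow concrete, here is a minimal usage sketch based on the officially released segment-anything package (the checkpoint file name and the point coordinates below are placeholders, not values from the paper):

import numpy as np
import cv2
from segment_anything import sam_model_registry, SamPredictor

# Load a pretrained SAM; the checkpoint path is a placeholder
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")
sam.to("cuda")
predictor = SamPredictor(sam)

# Read an RGB image and compute its image embedding once
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# A single foreground point prompt (label 1 = foreground, 0 = background)
point_coords = np.array([[500, 375]])
point_labels = np.array([1])

# An ambiguous prompt, so ask for multiple masks and keep the one with the best predicted IoU
masks, scores, logits = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,
)
best_mask = masks[np.argmax(scores)]   # (H, W) boolean mask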

1.2 Model structure (image encoder + prompt encoder + mask decoder) and training

The model structure is as follows

1.2.1 The composition of the image encoder (ViT) and its implementation

SAM uses a Vision Transformer pre-trained with MAE (if you have forgotten what ViT looks like, review part 4 of this article), minimally adapted to high-resolution input. The encoder runs only once per image, before the prompt encoder. The input image of shape (c, h, w) is rescaled so that its long side becomes 1024 and the short side is padded, giving a (c, 1024, 1024) image; the image encoder then produces a 16x-downsampled feature map of size (256, 64, 64)
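
To make the shapes concrete, here is a small sketch of that preprocessing arithmetic (resize the long side to 1024, pad the short side, 16x patch downsampling); the helper name preprocess_size is made up for illustration:

def preprocess_size(h: int, w: int, target: int = 1024, patch: int = 16):
    # Scale so that the long side becomes `target`, keeping the aspect ratio
    scale = target / max(h, w)
    new_h, new_w = int(round(h * scale)), int(round(w * scale))
    # The short side is then zero-padded up to `target`
    pad_h, pad_w = target - new_h, target - new_w
    # The ViT encoder downsamples by the patch size (16x), and the neck outputs 256 channels
    feat_h, feat_w = target // patch, target // patch
    return (new_h, new_w), (pad_h, pad_w), (256, feat_h, feat_w)

# Example: a 900x1200 photo
print(preprocess_size(900, 1200))   # ((768, 1024), (256, 0), (256, 64, 64))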

As for its code implementation, it mainly consists of the following classes (a quick shape check follows them)

  1. First, the ImageEncoderViT class, an image encoder based on the Vision Transformer; it inherits from nn.Module
    import torch  
    import torch.nn as nn  
    import torch.nn.functional as F  
    from typing import Optional, Tuple, Type  
    
    # 导入.common模块中的LayerNorm2d和MLPBlock
    from .common import LayerNorm2d, MLPBlock  
    
    # 定义ImageEncoderViT类,这是一个基于Vision Transformer的图像编码器,该类从nn.Module继承
    class ImageEncoderViT(nn.Module):  
        # 类的构造函数,定义了一系列的参数,例如图像大小,块大小,输入通道数,嵌入维度,Transformer的深度,注意力头部数等。
        def __init__(  
            self,
            img_size: int = 1024,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.0,
            out_chans: int = 256,
            qkv_bias: bool = True,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            act_layer: Type[nn.Module] = nn.GELU,
            use_abs_pos: bool = True,
            use_rel_pos: bool = False,
            rel_pos_zero_init: bool = True,
            window_size: int = 0,
            global_attn_indexes: Tuple[int, ...] = (),
        ) -> None:
    
            # 使用super函数调用父类的初始化函数
            super().__init__()  
            # 将图像大小保存为类的一个属性
            self.img_size = img_size  
    
            # 创建PatchEmbed实例,用于将输入图像划分为多个patch,并将每个patch嵌入到一个向量空间中
            self.patch_embed = PatchEmbed(  
                kernel_size=(patch_size, patch_size),
                stride=(patch_size, patch_size),
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
    
            # 创建位置嵌入属性,如果使用绝对位置嵌入,则初始化这个属性
            self.pos_embed: Optional[nn.Parameter] = None
            if use_abs_pos:
                self.pos_embed = nn.Parameter(
                    torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
                )
    
            # 创建Transformer的主体,包含多个Transformer block
            self.blocks = nn.ModuleList()  
            for i in range(depth):
                block = Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_rel_pos=use_rel_pos,
                    rel_pos_zero_init=rel_pos_zero_init,
                    window_size=window_size if i not in global_attn_indexes else 0,
                    input_size=(img_size // patch_size, img_size // patch_size),
                )
                self.blocks.append(block)
    
            # 创建neck属性,包含一个卷积层,一个LayerNorm层,另一个卷积层和另一个LayerNorm层
            self.neck = nn.Sequential(
                nn.Conv2d(
                    embed_dim,
                    out_chans,
                    kernel_size=1,
                    bias=False,
                ),
                LayerNorm2d(out_chans),
                nn.Conv2d(
                    out_chans,
                    out_chans,
                    kernel_size=3,
                    padding=1,
                    bias=False,
                ),
                LayerNorm2d(out_chans),
            )
    
        # 前向传播函数
        def forward(self, x: torch.Tensor) -> torch.Tensor:  
            # 对输入x进行patch embedding
            x = self.patch_embed(x)  
            # 如果使用了位置嵌入,将位置嵌入加到x上
            if self.pos_embed is not None:
                x = x + self.pos_embed
    
            # 将x通过所有的Transformer block
            for blk in self.blocks:  
                x = blk(x)
    
            # 将x通过neck,得到最终的输出
            x = self.neck(x.permute(0, 3, 1, 2))  
    
            return x
  2. The Block class, the basic building block of the Transformer, containing the attention mechanism and the feed-forward network; it inherits from nn.Module
    # 定义Block类,这是Transformer的基本组成模块,包括注意力机制和前馈神经网络。该类从nn.Module继承
    class Block(nn.Module):  
        # 类的构造函数,定义了一系列的参数,例如输入通道数,注意力头部数,mlp隐藏层与嵌入层的比例,是否添加偏置到查询,键,值,归一化层,激活函数等。
        def __init__(  
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            act_layer: Type[nn.Module] = nn.GELU,
            use_rel_pos: bool = False,
            rel_pos_zero_init: bool = True,
            window_size: int = 0,
            input_size: Optional[Tuple[int, int]] = None,
        ) -> None:
    
            # 使用super函数调用父类的初始化函数
            super().__init__()  
            # 创建第一个归一化层
            self.norm1 = norm_layer(dim)
            # 创建注意力机制层
            self.attn = Attention(  
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                input_size=input_size if window_size == 0 else (window_size, window_size),
            )
    
            # 创建第二个归一化层
            self.norm2 = norm_layer(dim)
            # 创建MLP层
            self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
    
            # 定义窗口大小
            self.window_size = window_size
    
        # 前向传播函数
        def forward(self, x: torch.Tensor) -> torch.Tensor:  
            # 保存输入x,以便稍后进行残差连接
            shortcut = x  
            # 对x进行第一次归一化处理
            x = self.norm1(x)
            # 如果定义了窗口大小,则对x进行窗口划分
            if self.window_size > 0:  
                H, W = x.shape[1], x.shape[2]
                x, pad_hw = window_partition(x, self.window_size)
    
            # 对x进行注意力处理
            x = self.attn(x)  
            # 如果定义了窗口大小,则对x进行窗口合并
            if self.window_size > 0:  
                x = window_unpartition(x, self.window_size, pad_hw, (H, W))
    
            # 对x进行残差连接
            x = shortcut + x  
            # 对x进行第二次归一化处理并通过MLP层,然后进行第二次残差连接
            x = x + self.mlp(self.norm2(x))  
    
            return x
  3. Define the Attention class, which is a multi-head attention mechanism block that supports relative position embedding. This class inherits from nn.Module
    # 定义Attention类,这是一个多头注意力机制的块,支持相对位置嵌入,该类从nn.Module继承
    class Attention(nn.Module):  
        # 类的构造函数,定义了一系列的参数,例如输入通道数,注意力头部数,是否添加偏置到查询,键,值,是否使用相对位置嵌入等。
        def __init__(  
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            use_rel_pos: bool = False,
            rel_pos_zero_init: bool = True,
            input_size: Optional[Tuple[int, int]] = None,
        ) -> None:
    
            # 使用super函数调用父类的初始化函数
            super().__init__()  
    
            # 保存注意力头部数
            self.num_heads = num_heads  
    
            # 计算每个注意力头部的维度
            head_dim = dim // num_heads  
    
            # 缩放因子
            self.scale = head_dim**-0.5  
    
            # 创建线性变换层,用于生成查询、键、值
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    
            # 创建线性变换层,用于将注意力加权后的值进行线性变换
            self.proj = nn.Linear(dim, dim)
    
            # 是否使用相对位置嵌入
            self.use_rel_pos = use_rel_pos  
            if self.use_rel_pos:
                assert (
                    input_size is not None
                ), "Input size must be provided if using relative positional encoding."
                # 初始化相对位置嵌入参数
                self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
                self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
    
        # 前向传播函数
        def forward(self, x: torch.Tensor) -> torch.Tensor:  
            B, H, W, _ = x.shape
            # 对输入x进行线性变换得到查询、键、值
            qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            # 将查询、键、值拆分出来
            q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
    
            # 计算注意力权重
            attn = (q * self.scale) @ k.transpose(-2, -1)
    
            # 如果使用相对位置嵌入,将相对位置嵌入添加到注意力权重中
            if self.use_rel_pos:
                attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
    
            # 对注意力权重进行softmax归一化
            attn = attn.softmax(dim=-1)
            # 计算注意力加权后的值,并重新调整形状
            x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
            # 将注意力加权后的值进行线性变换
            x = self.proj(x)
    
            return x
  4. Two functions, window_partition and window_unpartition, which split the input tensor into windows and merge the windows back. They are used in the Vision Transformer implementation to realize the windowed attention mechanism
    # 定义window_partition函数,用于将输入x分割为不重叠的窗口,并进行填充。
    def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Partition into non-overlapping windows with padding if needed.
        Args:
            x (tensor): input tokens with [B, H, W, C].
            window_size (int): window size.
    
        Returns:
            windows: windows after partition with [B * num_windows, window_size, window_size, C].
            (Hp, Wp): padded height and width before partition
        """
        B, H, W, C = x.shape
    
        # 计算需要进行填充的行和列的数量
        pad_h = (window_size - H % window_size) % window_size
        pad_w = (window_size - W % window_size) % window_size
        # 如果需要进行填充,则使用F.pad函数进行填充
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        # 计算填充后的高度和宽度
        Hp, Wp = H + pad_h, W + pad_w
    
        # 将输入x重新调整形状为窗口大小的倍数
        x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
        # 对调换维度进行重排列,并重新调整形状
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
        # 返回分割后的窗口和填充前的高度和宽度
        return windows, (Hp, Wp)
    
    
    # 定义window_unpartition函数,用于将窗口合并为原始序列,并移除填充。
    def window_unpartition(
        windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Window unpartition into original sequences and removing padding.
        Args:
            windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
            window_size (int): window size.
            pad_hw (Tuple): padded height and width (Hp, Wp).
            hw (Tuple): original height and width (H, W) before padding.
    
        Returns:
            x: unpartitioned sequences with [B, H, W, C].
        """
        Hp, Wp = pad_hw
        H, W = hw
        B = windows.shape[0] // (Hp * Wp // window_size // window_size)
        # 将窗口重新调整为原始序列
        x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
        # 对调换维度进行重排列,并重新调整形状
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    
        # 如果填充的高度或宽度大于原始高度或宽度,则移除填充部分
        if Hp > H or Wp > W:
            x = x[:, :H, :W, :].contiguous()
        # 返回合并后的序列
        return x
  5. Define two functions get_rel_pos and add_decomposed_rel_pos for handling relative position embedding. In the implementation of Vision Transformer, relative position embedding is used to provide relative position information between sequence elements to help the model better capture the relationship in the sequence. These functions are used to generate and apply relative positional embeddings
    # 定义get_rel_pos函数,根据查询和键的大小获取相对位置嵌入。
    def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
            query and key sizes.
        Args:
            q_size (int): size of query q.
            k_size (int): size of key k.
            rel_pos (Tensor): relative position embeddings (L, C).
    
        Returns:
            Extracted positional embeddings according to relative positions.
        """
        # 计算相对距离的最大值
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # 如果相对位置嵌入的形状与最大相对距离不一致,则进行插值处理
        if rel_pos.shape[0] != max_rel_dist:
            # 插值相对位置嵌入
            rel_pos_resized = F.interpolate(
                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                size=max_rel_dist,
                mode="linear",
            )
            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
        else:
            rel_pos_resized = rel_pos
    
        # 根据形状的不同,使用短边的长度进行坐标缩放
        q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
    
        return rel_pos_resized[relative_coords.long()]
    
    
    # 定义add_decomposed_rel_pos函数,计算分解的相对位置嵌入
    def add_decomposed_rel_pos(
        attn: torch.Tensor,
        q: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950
        Args:
            attn (Tensor): attention map.
            q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
            rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
            rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
            q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
            k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
    
        Returns:
            attn (Tensor): attention map with added relative positional embeddings.
        """
        q_h, q_w = q_size
        k_h, k_w = k_size
        # 获取相对位置嵌入
        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
        Rw = get_rel_pos(q_w, k_w, rel_pos_w)
    
        B, _, dim = q.shape
        r_q = q.reshape(B, q_h, q_w, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
        rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
    
        attn = (
            attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
        ).view(B, q_h * q_w, k_h * k_w)
    
        return attn
  6. The PatchEmbed class, which converts an image into patch embeddings. It uses a convolutional layer to project the input image into patch embeddings of the specified dimension; in the forward pass the input is projected through the convolution and the dimensions are permuted so that the output has the shape batch-height-width-channel

    # 定义PatchEmbed类,用于将图像转换为补丁嵌入。
    class PatchEmbed(nn.Module):
        """
        Image to Patch Embedding.
        """
    
        def __init__(
            self,
            kernel_size: Tuple[int, int] = (16, 16),
            stride: Tuple[int, int] = (16, 16),
            padding: Tuple[int, int] = (0, 0),
            in_chans: int = 3,
            embed_dim: int = 768,
        ) -> None:
            """
            Args:
                kernel_size (Tuple): kernel size of the projection layer.
                stride (Tuple): stride of the projection layer.
                padding (Tuple): padding size of the projection layer.
                in_chans (int): Number of input image channels.
                embed_dim (int): Patch embedding dimension.
            """
            # 使用super函数调用父类的初始化函数
            super().__init__()
    
            # 创建卷积层,用于将图像转换为补丁嵌入
            self.proj = nn.Conv2d(
                in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
            )
    
        # 前向传播函数
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # 将输入x进行投影
            x = self.proj(x)
            # 调换维度的顺序,B C H W -> B H W C
            x = x.permute(0, 2, 3, 1)
            return x
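
With the classes above in one module, a quick shape check confirms the (256, 64, 64) output described earlier (a sketch; the default arguments roughly correspond to a ViT-B configuration, and the forward pass is slow on CPU):

import torch

encoder = ImageEncoderViT()            # defaults: patch 16, embed_dim 768, depth 12, out_chans 256
x = torch.zeros(1, 3, 1024, 1024)      # one padded 1024x1024 RGB image
with torch.no_grad():
    feats = encoder(x)
print(feats.shape)                     # torch.Size([1, 256, 64, 64]), i.e. 16x downsampled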

1.2.2 Prompt encoder

Prompts are divided into 2 categories: sparse (point/box/text) and dense (mask)

  • Sparse prompts (points, boxes, text) are mapped to 256-dimensional embeddings
    A point is represented as the sum of a positional encoding of its location and one of two learned embeddings indicating whether the point is foreground or background
    A box is represented by an embedding pair: one learned embedding for its top-left corner and one for its bottom-right corner

    Text is encoded with CLIP's text encoder
  • Dense prompts (masks) are input at 1/4 the resolution of the input image, then downscaled a further 4× by two 2×2, stride-2 convolutions with output channels 4 and 16, and a final 1×1 convolution maps the channel dimension to 256 (a shape check follows the code below)
    The mask embedding and the image embedding are then combined element-wise (added, as in the decoder code further below; this can be understood as the mask features modulating the image features)

Its code is implemented as

import numpy as np
import torch
from torch import nn
from typing import Any, Optional, Tuple, Type
from .common import LayerNorm2d


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        The PromptEncoder for the SAM model, which encodes prompts for input to the mask decoder.

        Arguments:
          embed_dim (int): the prompts' embedding dimension
          image_embedding_size (tuple(int, int)): the spatial size of the image embedding, as (H, W).
          input_image_size (int): the padded size of the image input to the image encoder, as (H, W).
          mask_in_chans (int): the number of hidden channels used for encoding the input masks.
          activation (nn.Module): the activation to use when encoding the input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # 正/负点 + 2个框角
        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts, applied to a dense set of points the size of the image encoding.

        Returns:
          torch.Tensor: positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w).
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """嵌入点提示。"""
        points = points + 0.5  # 移动到像素的中心
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """嵌入框提示。"""
        boxes = boxes + 0.5  # 移动到像素的中心
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """嵌入遮罩输入。"""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different kinds of prompts and returns sparse and dense embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates and labels to embed.
          boxes (torch.Tensor or none): boxes to embed.
          masks (torch.Tensor or none): masks to embed.

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, of shape BxNx(embed_dim), where N is determined by the number of input points and boxes.
          torch.Tensor: dense mask embeddings, of shape Bx(embed_dim)x(embed_H)x(embed_W).
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """对归一化到[0,1]的点进行位置编码。"""
        # 假设坐标在[0, 1]^2的正方形内,并具有d_1 x ... x d_n x 2的形状
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # 输出形状为d_1 x ... x d_n x C
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """为指定大小的网格生成位置编码。"""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """对未归一化到[0,1]的点进行位置编码。"""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
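
A minimal shape check for the PromptEncoder above (a sketch using the same hyperparameters SAM uses: 256-d embeddings, a 64x64 image embedding, a 1024x1024 padded input, and 16 hidden mask channels):

import torch

pe = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

# One foreground point at pixel (512, 512); label 1 = foreground, 0 = background
coords = torch.tensor([[[512.0, 512.0]]])             # (B=1, N=1, 2)
labels = torch.tensor([[1]])                          # (B=1, N=1)
# One box as (x1, y1, x2, y2), and one low-res mask at 1/4 of the 1024 input, i.e. 256x256
boxes = torch.tensor([[100.0, 100.0, 600.0, 600.0]])
mask = torch.zeros(1, 1, 256, 256)

sparse, dense = pe(points=(coords, labels), boxes=boxes, masks=mask)
print(sparse.shape)   # (1, 3, 256): 1 point + 2 box corners, each a 256-d embedding
print(dense.shape)    # (1, 256, 64, 64): the downscaled mask embedding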

1.2.3 Mask decoder

Mask decoder module: learnable output tokens are inserted alongside the prompt embeddings and serve as the decoder's outputs

For the left part of the figure below, the following 4 steps are performed in sequence

  1. Self-attention over the tokens (prompt tokens + output tokens)
  2. Cross-attention from the tokens (as queries) to the image embedding
  3. A point-wise MLP that updates each token
  4. Cross-attention from the image embedding (as queries) to the tokens

This block is repeated twice, every attention sub-layer is wrapped in a residual connection, and the decoder finally outputs masks and IoU scores. The code is implemented as follows (a shape check follows it)

import torch
from torch import Tensor, nn

import math
from typing import Tuple, Type

from .common import MLPBlock


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to the input image using queries whose positional embedding is supplied.

        Arguments:
          depth (int): the number of layers in the transformer
          embedding_dim (int): the channel dimension of the input embeddings
          num_heads (int): the number of attention heads; embedding_dim must be a multiple of num_heads
          mlp_dim (int): the internal channel dimension of the MLP block
          activation (nn.Module): the activation used in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Arguments:
          image_embedding (torch.Tensor): the image to attend to. Shape should be B x embedding_dim x h x w, where h and w can be any value.
          image_pe (torch.Tensor): the positional encoding added to the image. Must have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding added to the query points. Shape must be B x N_points x embedding_dim, where N_points can be any value.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # 准备查询
        queries = point_embedding
        keys = image_embedding

        # 应用Transformer块和最终的LayerNorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # 应用从点到图像的最终注意力层
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers:
        (1) self-attention over the sparse inputs,
        (2) cross-attention of the sparse inputs to the dense inputs,
        (3) an MLP block on the sparse inputs,
        (4) cross-attention of the dense inputs to the sparse inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the MLP block
          activation (nn.Module): the activation of the MLP block
          skip_first_layer_pe (bool): whether to skip the positional encoding in the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # 自注意力块
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # 交叉注意力块,将token与图像嵌入进行注意力操作
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP块
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # 交叉注意力块,将图像嵌入与token进行注意力操作
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class Attention(nn.Module):
    """
    An attention layer that allows downscaling the embedding size after projection of the queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # 输入投影
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # 分割为头部
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # 注意力操作
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # 获取输出
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
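
A minimal shape check for the two-way transformer above (a sketch using SAM's decoder settings: depth 2, 256-d embeddings, 2048 MLP dimension, 8 heads):

import torch

two_way = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)

image_embedding = torch.randn(1, 256, 64, 64)   # from the image encoder (plus the dense prompt embedding)
image_pe = torch.randn(1, 256, 64, 64)          # dense positional encoding of the same shape
tokens = torch.randn(1, 5, 256)                 # output tokens + sparse prompt tokens

queries, keys = two_way(image_embedding, image_pe, tokens)
print(queries.shape)   # (1, 5, 256): the updated tokens
print(keys.shape)      # (1, 4096, 256): the updated image embedding, flattened from 64x64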

For the right part of the figure below

  • After running the decoder, the updated image embedding is upsampled 4× with two transposed convolution layers (it is then only 4× downscaled relative to the input image)
  • The output tokens then attend once more to the image embedding, and the updated output token embeddings are passed to a small 3-layer MLP that outputs a vector matching the channel dimension of the upscaled image embedding
  • Finally, a mask is predicted as a spatially point-wise product between the upscaled image embedding and the MLP's output

A few details are worth noting

  1. The transformer uses an embedding dimension of 256. Its MLP blocks have a large internal dimension of 2048, but the MLP is applied only to the prompt tokens, of which there are relatively few (rarely more than 20). In the cross-attention layers, where the image embedding is 64×64, the channel dimension of the queries, keys, and values is halved to 128 for computational efficiency. All attention layers use 8 heads
  2. The transposed convolutions used to upscale the output image embedding are 2×2 with stride 2, with output channel dimensions 64 and 32, GELU activations, and layer normalization between them
  3. To handle output ambiguity (one prompt may correspond to several valid masks; a point on a shirt, for example, can mean the shirt or the person wearing it), several output tokens are used to predict multiple masks at once instead of a single mask. Three masks are predicted by default, because three levels (whole, part, and subpart) are usually enough to describe nested masks

    During training, only the mask with the smallest loss is backpropagated. To rank the masks, a small head is added that predicts the IoU between each mask and the object it covers. When multiple prompts are given, the ambiguity largely disappears and the predicted masks become similar; to avoid degenerate losses and guarantee a single unambiguous mask, only one mask is predicted in that case. This is implemented by adding a fourth output token for an additional mask prediction: this fourth mask is never returned for a single prompt and is the only mask returned when there are multiple prompts

Its code is implemented as follows: a MaskDecoder class that predicts masks from the image embedding and the prompt embeddings using the transformer above, plus an MLP class, i.e. a small multi-layer perceptron (a shape check follows the code)

import torch
from torch import nn
from torch.nn import functional as F
from typing import List, Tuple, Type
from .common import LayerNorm2d

# 定义MaskDecoder类,继承自nn.Module
class MaskDecoder(nn.Module):
    # 构造函数
    def __init__(
        self,
        *,
        transformer_dim: int,        # Transformer的维度
        transformer: nn.Module,
        num_multimask_outputs: int = 3,    # 多重掩码输出的数量,默认为3
        activation: Type[nn.Module] = nn.GELU,  # 激活函数类型,默认为nn.GELU
        iou_head_depth: int = 3,           # 预测掩码质量的MLP的深度,默认为3
        iou_head_hidden_dim: int = 256,    # 预测掩码质量的MLP的隐藏维度,默认为256
    ) -> None:

        super().__init__()        # 调用父类的初始化函数
        self.transformer_dim = transformer_dim    # 初始化Transformer的维度
        self.transformer = transformer            # 初始化Transformer模块

        # 初始化多重掩码输出的数量
        self.num_multimask_outputs = num_multimask_outputs
        self.iou_token = nn.Embedding(1, transformer_dim)    # 初始化IOU嵌入
        self.num_mask_tokens = num_multimask_outputs + 1     # 初始化掩码token的数量

        # 初始化掩码token的嵌入
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        # 初始化输出缩放的网络
        self.output_upscaling = nn.Sequential(
            # 卷积反卷积2d
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),

            # 激活函数
            activation(),

            # 卷积反卷积2d
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),

            activation(),
        )
        # 初始化输出超网络的MLP列表
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        # 初始化IOU预测头
        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

    # 前向传播函数
    def forward(
        self,
        image_embeddings: torch.Tensor,       # 图像的嵌入表示
        image_pe: torch.Tensor,               # 图像的位置编码
        sparse_prompt_embeddings: torch.Tensor,    # 稀疏提示的嵌入表示
        dense_prompt_embeddings: torch.Tensor,     # 密集提示的嵌入表示
        multimask_output: bool,                    # 是否返回多个掩码
    ) -> Tuple[torch.Tensor, torch.Tensor]:        # 预测的掩码

        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # 根据multimask_output选择掩码输出
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # 准备输出
        return masks, iou_pred

    # 预测掩码函数
    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        # Predicts masks. See 'forward' for more details
        """
        # 拼接输出token
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # 在batch方向上扩展每个图像数据,以便在mask上进行处理
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # 运行Transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # 缩放mask嵌入并使用mask tokens预测masks
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # 生成mask质量预测
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


# MLP类,继承自nn.Module
class MLP(nn.Module):
    # 构造函数
    def __init__(
        self,
        input_dim: int,         # 输入维度
        hidden_dim: int,        # 隐藏层维度
        output_dim: int,        # 输出维度
        num_layers: int,        # 层数
        sigmoid_output: bool = False,    # 是否在输出上应用sigmoid函数
    ) -> None:

        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        # 初始化各层
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    # 前向传播函数
    def forward(self, x):
        # 遍历每一层,逐层处理输入
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)

        # 如果sigmoid_output为真,对输出应用sigmoid函数
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
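
A matching shape check for the mask decoder (a sketch; it reuses the TwoWayTransformer from the previous block with SAM's settings):

import torch

decoder = MaskDecoder(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048),
)

image_embeddings = torch.randn(1, 256, 64, 64)
image_pe = torch.randn(1, 256, 64, 64)
sparse = torch.randn(1, 3, 256)                 # e.g. 1 point + 2 box corners
dense = torch.randn(1, 256, 64, 64)

masks, iou_pred = decoder(
    image_embeddings=image_embeddings,
    image_pe=image_pe,
    sparse_prompt_embeddings=sparse,
    dense_prompt_embeddings=dense,
    multimask_output=True,
)
print(masks.shape)      # (1, 3, 256, 256): three candidate low-res masks (whole / part / subpart)
print(iou_pred.shape)   # (1, 3): the predicted IoU score of each mask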

After the three modules above are implemented, they can be combined and called directly for actual segmentation.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): the backbone used to encode the image into image embeddings.
          prompt_encoder (PromptEncoder): encodes the various kinds of input prompts.
          mask_decoder (MaskDecoder): predicts masks from the image embeddings and encoded prompts.
          pixel_mean (list(float)): mean values for pixel normalization of the input image.
          pixel_std (list(float)): standard deviation for pixel normalization of the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from the provided images and prompts.
        If the prompts are not known in advance, SamPredictor is recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): a list of input images, each a dictionary with the following keys. A prompt key can be omitted if it is not present.
              'image': the image as a torch tensor in 3xHxW format, already transformed for model input.
              'original_size': (tuple(int, int)) the original size of the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) batched point prompts for this image, of shape BxNx2. Already transformed to the model's input frame.
              'point_labels': (torch.Tensor) labels for the batched point prompts, of shape BxN.
              'boxes': (torch.Tensor) batched box inputs, of shape Bx4. Already transformed to the model's input frame.
              'mask_inputs': (torch.Tensor) batched mask inputs to the model, in the form Bx1xHxW.
          multimask_output (bool): whether the model should predict multiple disambiguating masks or return a single mask.

        Returns:
          (list(dict)): a list over the input images, where each element is a dictionary with the following keys.
              'masks': (torch.Tensor) batched binary mask predictions of shape BxCxHxW, where B is the number of input prompts, C is determined by multimask_output, and (H, W) is the original image size.
              'iou_predictions': (torch.Tensor) the model's prediction of mask quality, of shape BxC.
              'low_res_logits': (torch.Tensor) low-resolution logits of shape BxCxHxW, where H=W=256. These can be passed as mask input to a subsequent prediction iteration.
        """
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Removes padding and upscales the masks to the original image size.

        Arguments:
          masks (torch.Tensor): batched masks from the MaskDecoder, in BxCxHxW format.
          input_size (tuple(int, int)): the size of the image input to the model, as (H, W). Used to remove padding.
          original_size (tuple(int, int)): the original size of the image before resizing for model input, as (H, W).

        Returns:
          (torch.Tensor): batched masks in BxCxHxW format, where (H, W) is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """归一化像素值并填充为方形输入。"""
        # 归一化颜色
        x = (x - self.pixel_mean) / self.pixel_std

        # 填充
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x

1.2.4 Model training

Training simulates interactive segmentation: a foreground point or a box is sampled at random from the target mask. The point is sampled from the ground-truth mask; the box is the ground-truth box with random noise added, up to 10% of the box side length and capped at 20 pixels

After the first prompt produces a mask, subsequent points are sampled from the region where the predicted mask and the ground-truth mask disagree (a sampling sketch follows below)

  • If the newly sampled point lies in a false-negative region, it is used as a foreground point
  • If it lies in a false-positive region, it is used as a background point

At the same time, the predicted mask is fed back as a prompt for the next iteration, using the unthresholded mask logits instead of the binarized mask (no threshold is applied; the default threshold is 0)

In practice 8 sampled points were found to be appropriate (16 brings no obvious gain). To encourage the model to make use of the supplied mask, 2 extra iterations add no new points, giving 11 iterations in total: one initial prompt, 8 iterations of point sampling as above, and 2 iterations without new points (which let the model refine the mask). Because the mask decoder is relatively light, this many iterations is affordable
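
A small sketch of this correction-point sampling, assuming binary masks as boolean tensors (the function name sample_correction_point is made up for illustration):

import torch

def sample_correction_point(pred_mask: torch.Tensor, gt_mask: torch.Tensor):
    """Sample one point from the error region between prediction and ground truth.

    Returns ((x, y), label), where label 1 = foreground (sampled from a false-negative
    region) and label 0 = background (sampled from a false-positive region).
    """
    false_neg = gt_mask & ~pred_mask      # missed object pixels -> new foreground click
    false_pos = pred_mask & ~gt_mask      # wrongly included pixels -> new background click
    error = false_neg | false_pos
    ys, xs = torch.nonzero(error, as_tuple=True)
    if len(ys) == 0:
        return None                        # the prediction already matches the ground truth
    i = torch.randint(len(ys), (1,)).item()
    y, x = ys[i].item(), xs[i].item()
    label = 1 if false_neg[y, x] else 0
    return (x, y), label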

  • Loss
    The mask loss is a linear combination of focal loss and dice loss with coefficients 20:1; the IoU prediction is trained with MSE loss (see the sketch after this list)
  • Training time
    256 A100 GPUs for 3-5 days (an A100 costs roughly 60,000 RMB, so 256 of them is well over 10 million, you get the idea..)
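
A sketch of this loss, assuming pred_masks of shape (B, 3, H, W) as logits, gt_mask of shape (B, H, W), and predicted IoUs of shape (B, 3); the 20:1 focal/dice weighting and the min-over-masks selection follow the text above, while the focal and dice helpers are simplified illustrations:

import torch
import torch.nn.functional as F

def dice_loss(logits, target, eps=1e-6):
    p = torch.sigmoid(logits).flatten(1)
    t = target.flatten(1)
    inter = (p * t).sum(-1)
    return 1 - (2 * inter + eps) / (p.sum(-1) + t.sum(-1) + eps)

def focal_loss(logits, target, alpha=0.25, gamma=2.0):
    ce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = torch.exp(-ce)                                 # probability of the true class
    w = alpha * target + (1 - alpha) * (1 - target)
    return (w * (1 - p_t) ** gamma * ce).flatten(1).mean(-1)

def sam_loss(pred_masks, iou_pred, gt_mask):
    B, M, H, W = pred_masks.shape                        # M = 3 candidate masks
    gt = gt_mask.unsqueeze(1).expand(-1, M, -1, -1).float()
    flat_pred = pred_masks.reshape(B * M, H, W)
    flat_gt = gt.reshape(B * M, H, W)
    # Linear combination of focal and dice losses with a 20:1 ratio, per candidate mask
    mask_loss = (20.0 * focal_loss(flat_pred, flat_gt) + dice_loss(flat_pred, flat_gt)).view(B, M)
    # Only the lowest-loss mask is backpropagated
    selected = mask_loss.min(dim=1).values
    # The IoU head is trained with MSE against the actual IoU of each predicted mask
    with torch.no_grad():
        bin_pred = (pred_masks > 0).float()
        inter = (bin_pred * gt).flatten(2).sum(-1)
        union = (bin_pred + gt).clamp(0, 1).flatten(2).sum(-1)
        true_iou = inter / union.clamp(min=1)
    return selected.mean() + F.mse_loss(iou_pred, true_iou)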

1.3 Data engine: assisted-manual, semi-automatic, fully automatic

  • Assisted-manual annotation
    Annotation is done in SAM's browser-based interactive segmentation tool using "brush" and "eraser" tools, with the model producing masks in real time. Annotators are asked to first label objects they can name and to label them in order of prominence; if a mask takes more than 30s, they move on to the next image
    SAM is first trained on public datasets, then retrained on the newly collected masks. With more data the image encoder gets stronger; the model was retrained 6 times during this stage. As the model improved, the average annotation time per mask dropped from 34s to 14s, and the average number of masks per image grew from 22 to 44. In this stage, 4.3 million masks were collected from 120,000 images
  • Semi-automatic annotation
    This stage increases the diversity of the masks: confident masks are detected first and pre-filled into the image, and annotators then label the remaining unlabeled objects. To detect confident masks, a generic object box detector is trained on the masks from the first stage. In this stage, 5.9 million masks were generated from 180,000 images. The model was retrained on the newly collected data; the average annotation time went back up to 34s because the remaining masks are harder, and the number of masks per image increased from 44 to 72
  • Fully automatic annotation
    This stage uses the large and diverse masks from the first two stages, together with the model's ability to output valid masks from ambiguous prompts (see the mask decoder above). A (32, 32) grid of points is placed on each image and each point predicts a set of masks; if a point falls on a part or subpart, the model returns the subpart, the part, and the whole object. Masks are then filtered by the predicted IoU to keep confident masks, and by stability (a mask is stable if thresholding the probability map at 0.5-δ and 0.5+δ yields similar masks); finally, NMS removes duplicates among the confident and stable masks. To improve smaller masks, multiple overlapping zoomed-in image crops are also processed (a generation sketch follows below)

In the end, 1.1 billion high-quality masks were generated for the 11 million images
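
For the fully automatic stage, the released repository provides a SamAutomaticMaskGenerator that implements the grid prompting, IoU and stability filtering, and NMS described above; a minimal usage sketch (the checkpoint path is a placeholder, and the thresholds shown are illustrative defaults):

import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth").to("cuda")

generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,              # the 32x32 grid of point prompts
    pred_iou_thresh=0.88,            # keep only confident masks (by predicted IoU)
    stability_score_thresh=0.95,     # keep only stable masks (thresholding at 0.5 +/- delta)
    box_nms_thresh=0.7,              # NMS to remove duplicate masks
)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
masks = generator.generate(image)    # list of dicts with 'segmentation', 'predicted_iou', 'stability_score', ...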

Data situation

  • Images: 11 million images licensed from partners, downsampled so that the short side is 1500 pixels
  • Masks: 99.1% were generated fully automatically, and comparative analysis shows their quality is high. To evaluate it, 500 images (about 50,000 masks) were randomly sampled and professionally re-annotated; 94% of the mask pairs had an IoU above 90%
  • The data distribution is broad: images come from all over the world, there are far more masks, and the data is less biased

References and Recommended Reading

  1. SAM original paper published by Meta
  2. Several interpretations of the SAM paper: "[Paper Interpretation] Meta AI SAM (Segment Anything) segments everything", SAM interpretation PPT
