【图像分割】【深度学习】SAM官方Pytorch代码-Prompt encoder模块ProEnco网络解析

【图像分割】【深度学习】SAM官方Pytorch代码-Prompt encoder模块PromptEncoder网络解析

Segment Anything:建立了迄今为止最大的分割数据集,在1100万张图像上有超过1亿个掩码,模型的设计和训练是灵活的,其重要的特点是Zero-shot(零样本迁移性)转移到新的图像分布和任务,一个图像分割新的任务、模型和数据集。SAM由三个部分组成:一个强大的图像编码器(Image encoder)计算图像嵌入,一个提示编码器(Prompt encoder)嵌入提示,然后将两个信息源组合在一个轻量级掩码解码器(Mask decoder)中来预测分割掩码。本博客将讲解Prompt encoder模块的深度学习网络代码。


前言

在详细解析SAM代码之前,首要任务是成功运行SAM代码【win10下参考教程】,后续学习才有意义。本博客讲解Prompt encoder模块的深度网络代码,不涉及其他功能模块代码。


PromptEncoder网络简述

SAM模型关于ProEnco网络的配置

博主以sam_vit_b为例,详细讲解ViT网络的结构。
代码位置:segment_anything/build_sam.py

def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        # 图像编码channel
        encoder_embed_dim=768,
        # 主体编码器的个数
        encoder_depth=12,
        # attention中head的个数
        encoder_num_heads=12,
        # 需要将相对位置嵌入添加到注意力图的编码器( Encoder Block)
        encoder_global_attn_indexes=[2, 5, 8, 11],
        # 权重
        checkpoint=checkpoint,
    )

sam模型中prompt_encoder模块初始化

prompt_encoder=PromptEncoder(
    # 提示编码channel(和image_encoder输出channel一致,后续会融合)
    embed_dim=prompt_embed_dim,
    # mask的编码尺寸(和image_encoder输出尺寸一致)
    image_embedding_size=(image_embedding_size, image_embedding_size),
    # 输入图像的标准尺寸
    input_image_size=(image_size, image_size),
    # 对输入掩码编码的通道数
    mask_in_chans=16,
),

ProEnco网络结构与执行流程

Prompt encoder源码位置:segment_anything/modeling/prompt_encoder.py
ProEnco网络(PromptEncoder类)结构参数配置。

def __init__(
    self,
    embed_dim: int,                         # 提示编码channel
    image_embedding_size: Tuple[int, int],  # # mask的编码尺寸
    input_image_size: Tuple[int, int],      # 输入图像的标准尺寸
    mask_in_chans: int,                     # 输入掩码编码的通道数
    activation: Type[nn.Module] = nn.GELU,  # 激活层
) -> None:
    super().__init__()
    self.embed_dim = embed_dim              # 提示编码channel
    self.input_image_size = input_image_size                # 输入图像的标准尺寸
    self.image_embedding_size = image_embedding_size        # mask的编码尺寸
    self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
    self.num_point_embeddings: int = 4                      # 4个点:正负点,框的俩个点
    point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]   # 4个点的嵌入向量
    # nn.ModuleList它是一个存储不同module,并自动将每个module的parameters添加到网络之中的容器
    self.point_embeddings = nn.ModuleList(point_embeddings)                     # 4个点的嵌入向量添加到网络
    self.not_a_point_embed = nn.Embedding(1, embed_dim)                         # 不是点的嵌入向量
    self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])           # mask的输入尺寸
    self.mask_downscaling = nn.Sequential(                                                      # 输入mask时 4倍下采样
        nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
        LayerNorm2d(mask_in_chans // 4),
        activation(),
        nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
        LayerNorm2d(mask_in_chans),
        activation(),
        nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
    )
    self.no_mask_embed = nn.Embedding(1, embed_dim)                         # 没有mask输入时 嵌入向量

SAM模型中ProEnco网络结构如下图所示:

ProEnco网络(PromptEncoder类)在特征提取中的几个基本步骤:

  1. Embed_Points:标记点编码(标记点由点转变为向量)
  2. Embed_Boxes:标记框编码(标记框由点转变为向量)
  3. Embed_Masks:mask编码(mask下采样保证与Image encoder输出一致)
def forward(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # 获得 batchsize  当前predict为1
    bs = self._get_batch_size(points, boxes, masks)
    
    # -----sparse_embeddings----
    sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
    if points is not None:
        coords, labels = points
        point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
        sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
    if boxes is not None:
        box_embeddings = self._embed_boxes(boxes)
        sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
    # -----sparse_embeddings----
    
    # -----dense_embeddings----
    if masks is not None:
        dense_embeddings = self._embed_masks(masks)
    else:
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
        )
    # -----dense_embeddings----
    
    return sparse_embeddings, dense_embeddings

获取batchsize

def _get_batch_size(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> int:
    if points is not None:
        return points[0].shape[0]
    elif boxes is not None:
        return boxes.shape[0]
    elif masks is not None:
        return masks.shape[0]
    else:
        return 1

获取设备型号

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

ProEnco网络基本步骤代码详解

Embed_Points


标记点预处理,将channel由2变成embed_dim(MatMul:forward_with_coords),然后再加上位置编码权重。

2:坐标(h,w)
embed_dim:提示编码的channel

Embed_Points结构如下图所示:

def _embed_points(
    self,
    points: torch.Tensor,
    labels: torch.Tensor,
    pad: bool,
) -> torch.Tensor:
    # 移到像素中心
    points = points + 0.5
    # points和boxes联合则不需要pad
    if pad:
        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)  # B,1,2
        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)     # B,1
        points = torch.cat([points, padding_point], dim=1)                          # B,N+1,2
        labels = torch.cat([labels, padding_label], dim=1)                          # B,N+1
    point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)  # B,N+1,2f
    # labels为-1是非标记点,设为非标记点权重
    point_embedding[labels == -1] = 0.0
    point_embedding[labels == -1] += self.not_a_point_embed.weight
    # labels为0是背景点,加上背景点权重
    point_embedding[labels == 0] += self.point_embeddings[0].weight
    # labels为1的目标点,加上目标点权重
    point_embedding[labels == 1] += self.point_embeddings[1].weight
    return point_embedding

个人理解:pad的作用相当于box占位符号,box和points可以联合标定完成图像分割的,但是此时的box只能有一个,不能有多个。

Embed_Boxes


标记框预处理,将channel由4到2再变成embed_dim(MatMul:forward_with_coords),然后再加上位置编码权重。

4:坐标(h1,w1,h2,w2) -->起始点与末位点
2:坐标(h,w)–>4 reshape 成 2×2
embed_dim:提示编码的channel

Embed_Boxes结构如下图所示:

def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    # 移到像素中心
    boxes = boxes + 0.5
    coords = boxes.reshape(-1, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)    #
    # 目标框起始点的和末位点分别加上权重
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    return corner_embedding

个人理解:boxes reshape 后 batchsize是会增加的,B,N,4–>BN,2,2
因此这里可以得出box和points联合标定时,box为什么只能是一个,而不能是多个。

Embed_Masks


mask的输出尺寸是Image encoder模块输出的图像编码尺寸的4倍,因此为了保持一致,需要4倍下采样。
Embed_Masks结构如下图所示:

def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
    # mask下采样4倍
    mask_embedding = self.mask_downscaling(masks)
    return mask_embedding
# 在PromptEncoder的__init__定义
self.mask_downscaling = nn.Sequential(                                                      # 输入mask时 4倍下采样
    nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans // 4),
    activation(),
    nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans),
    activation(),
    nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )

假设没有mask输入,则将no_mask_embed编码扩展到与图像编码一致的尺寸代替mask。

# 在PromptEncoder的forward定义
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
    bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)

PositionEmbeddingRandom

用于将标记点和标记框的坐标进行提示编码预处理。

def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
    super().__init__()
    if scale is None or scale <= 0.0:
        scale = 1.0
    # 理解为模型的常数 [2,f]
    self.register_buffer(
        "positional_encoding_gaussian_matrix",
        scale * torch.randn((2, num_pos_feats)),
    )

将标记点的坐标具体的位置转变为[0~1]之间的比例位置

def forward_with_coords(
    self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
    coords = coords_input.clone()
    # 将坐标位置缩放到[0~1]之间
    coords[:, :, 0] = coords[:, :, 0] / image_size[1]
    coords[:, :, 1] = coords[:, :, 1] / image_size[0]
    # B,N+1,2-->B,N+1,2f
    return self._pe_encoding(coords.to(torch.float))

标记点位置编码

因为sin和cos,编码的值归一化至 [-1,1],源码注释是[0,1],博主经过实验发现注释不对

def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
    # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
    coords = 2 * coords - 1
    # B,N+1,2 × 2,f --> B,N+1,f
    coords = coords @ self.positional_encoding_gaussian_matrix
    coords = 2 * np.pi * coords
    # outputs d_1 x ... x d_n x C shape
    # B,N+1,2f
    return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

总结

尽可能简单、详细的介绍SAM中Prompt encoder模块的ProEnco网络的代码。后续会讲解SAM的其他模块的代码。

猜你喜欢

转载自blog.csdn.net/yangyu0515/article/details/130389786