mmdetection之Detr源码解读

前言

 本文首先简要介绍Detr论文原理,之后在介绍mmdetection中的detr实现。
 论文地址:chrome-extension://ibllepbpahcoppkjjllbabhnigcbffpi/https://arxiv.53yu.com/pdf/2005.12872.pdf,

1、原理简要介绍

在这里插入图片描述
 整体流程:在给定一张输入图像后,1)特征向量提取:首先经过ResNet提取图像的最后一层特征图F。注意此处仅仅用了一层特征图,是因为后续计算复杂度原因,另外,由于仅用最后一层特征图,故对小目标检测不友好,这也是后续deformable detr改进的原因。 2)添加位置编码信息:经F拉平成一维张量并添加上位置编码信息得到I。3)Transformer中encoder部分4)Transformer中decoder部分,学习位置嵌入object queries。5)FFN部分:6)后续匈牙利匹配+损失计算。

2、mmdetection中源码介绍

2.1. 整体逻辑

  Detr的内部逻辑如下:在mmdet/models/detector/single_stage.py。即首先提取图像特征向量,之后经过DetrHead来计算最终的损失。

def forward_train(self,
                  img,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None):
    super(SingleStageDetector, self).forward_train(img, img_metas)
    x = self.extract_feat(img) # 提取图像特征向量  
    # 经过DetrHead得到loss                   
    losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                          gt_labels, gt_bboxes_ignore)
    return losses

2.2. 图像特征向量提取

 mmdet中提取图像特征向量的config配置文件如下,可以发现用ResNet50并只提取了最后一层特征层,即out_indices=(3,)。关于内部原理参见我的博文:mmdet之backbone介绍。

backbone=dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(3, ),     # detr仅要resnet50的最后一层特征图,并不需要FPN
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=False),
    norm_eval=True,
    style='pytorch',
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))

2.3. 给图像特征向量添加位置编码信息

 本部分代码来自mmdet/models/dense_heads/detr_head.py。
  mmdet中生成位置编码信息借助的是mask矩阵,所谓的mask就是为了统一批次大小而对图像进行了pad,被填充的部分在后续计算多头注意力时应该舍弃,故需要一个mask矩阵遮挡住,具体形状为[batch, h,w]这里先贴下生成mask的过程:

batch_size = x.size(0)   
input_img_h, input_img_w = img_metas[0]['batch_input_shape']# 一个批次图像大小
masks = x.new_ones((batch_size, input_img_h, input_img_w))  # [b,838,768]
for img_id in range(batch_size):
    img_h, img_w, _ = img_metas[img_id]['img_shape']    # 创建了一个mask,非0代表无效区域, 0 代表有效区域
    masks[img_id, :img_h, :img_w] = 0                   # 将pad部分置为1,非pad部分置为0.

 我这里简单贴下mask示意图:
在这里插入图片描述
  在有了mask基础上[batch,256,h,w],注意此时的hw是原图大小的;而输入图像的经过resnet50下采样后hw已经变了,所以还需进一步将mask下采样成和图像特征向量一样的shape。代码如下:

# interpolate masks to have the same spatial shape with x
masks = F.interpolate(
    masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1) # masks和x的shape一样:[b,27,24]

 后续便可以生成位置编码部分(mmdet/models/utils/position_encoding.py),该函数给masks的每个像素位置生成了一个256维的唯一的位置向量。我这简单写了个测试脚本:

import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils.positional_encoding import POSITIONAL_ENCODING  # 加载注册器

positional_encoding = dict(type='SinePositionalEncoding', num_feats=128, normalize=True)
self = build_positional_encoding(positional_encoding)
self.eval()
mask = torch.tensor([[[0,0,1],[0,0,1],[1,1,1]]], dtype= torch.uint8) # [1,3,3]
out = self(mask)          # [b,256,h,w]

 感兴趣可以看下mmdet关于位置编码这部分实现逻辑(只是做了简单注释):

def forward(self, mask):
    """Forward function for `SinePositionalEncoding`.

    Args:
        mask (Tensor): ByteTensor mask. Non-zero values representing
            ignored positions, while zero values means valid positions
            for this image. Shape [bs, h, w].

    Returns:
        pos (Tensor): Returned position embedding with shape
            [bs, num_feats*2, h, w].
    """
    # For convenience of exporting to ONNX, it's required to convert
    # `masks` from bool to int.
    mask = mask.to(torch.int)
    not_mask = 1 - mask       # 取反将1的位置视为图像区域
    y_embed = not_mask.cumsum(1, dtype=torch.float32)  # 累加1得到y方向坐标 [h]
    x_embed = not_mask.cumsum(2, dtype=torch.float32)  # 累加1得到x方向坐标 [w]
    # 归一化过程就是除以坐标中的max,而最后一行/列就是累加的最大的向量
    if self.normalize:
        y_embed = (y_embed + self.offset) / \
                  (y_embed[:, -1:, :] + self.eps) * self.scale # 取最后一行
        x_embed = (x_embed + self.offset) / \
                  (x_embed[:, :, -1:] + self.eps) * self.scale # 取最后一列
    # 创建一个[128]的特征向量
    dim_t = torch.arange(
        self.num_feats, dtype=torch.float32, device=mask.device)
    dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)  # 归一化的 [0,0,1,1,2,2,...,64,64]乘温度系数
    # 行列坐标分别除以dim_t得到每个点的128维的行列特征向量
    pos_x = x_embed[:, :, :, None] / dim_t                         # [b,h,w,128]
    pos_y = y_embed[:, :, :, None] / dim_t                         # [b,h,w,128]
    # use `view` instead of `flatten` for dynamically exporting to ONNX
    B, H, W = mask.size()
    # 分别采样奇数和偶数位置并执行sin和cos,并拼接[b,h,w,128]
    pos_x = torch.stack(
        (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
        dim=4).view(B, H, W, -1)
    pos_y = torch.stack(
        (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
        dim=4).view(B, H, W, -1)
    # 最后将横纵坐标拼接得到每个点唯一的256维度的位置向量
    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)     # [b,256,h,w]
    return pos

3、送入Transformer

3.1. 整体逻辑

  在得到图像特征向量x=[b,c,h,w]、masks[b,h,w]矩阵以及位置编码pos_embed[b,256,h,w]后,便可送入Transformer,关键是厘清encoder和decoder的QKV分别指啥,看代码:

扫描二维码关注公众号,回复: 14970724 查看本文章
memory = self.encoder(
    query=x,                     # [hw,b,c]
    key=None,
    value=None,
    query_pos=pos_embed,         # [hw,b,c]
    query_key_padding_mask=mask) # [b,hw]
target = torch.zeros_like(query_embed)   # decoder初始化全0
# out_dec: [num_layers, num_query, bs, dim]
out_dec = self.decoder(
    query=target,                # 全0的target, 后续在MultiHeadAttn中执行了
    key=memory,                  # query = query + query_pos又加回去了。
    value=memory,
    key_pos=pos_embed,
    query_pos=query_embed,       # [num_query, bs, dim]
    key_padding_mask=mask)
out_dec = out_dec.transpose(1, 2)

  其中encoder中q就是x,kv分别为None,query_pos代表位置编码,而query_key_padding_mask就是mask。decoder的q是全0的target,后续decoder会迭代更新q,而kv则 是memory,即encoder的输出;key_pos依旧是k的位置信息;query_embed即论文中Object query,可学习位置信息;key_padding_mask依然是mask。

3.2. encoder部分

  先看下encoder初始化部分,内部循环调用了6次BaseTransformerLayer,因此只需讲解一层EncoderLayer即可。

encoder=dict(
    type='DetrTransformerEncoder',
    num_layers=6,                        # 经过6层Layer
    transformerlayers=dict(              # 每层layer内部使用多头注意力
        type='BaseTransformerLayer',
        attn_cfgs=[
            dict(
                type='MultiheadAttention',
                embed_dims=256,           
                num_heads=8,
                dropout=0.1)
        ],
        feedforward_channels=2048,        # FFN中间层的维度   
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'ffn', 'norm'))), # 定义运算流程

 在来看下BaseTransformerLayer的forward部分,该部分可以损失detr的核心部分了,因为本质上mmdet内部只是封装了pytorch现有的nn.MultiHeadAtten函数。所以,需要理解nn.MultiHeadAttn中两种mask参数的含义,限于篇幅原因,这里可参考nn.Transformer来理解这两个mask。 不过简单理解就是:attn_mask在detr中没用到,仅用key_padding_mask。attn_mask是为了遮挡未来文本信息用的,而图像可以看到全部的信息,因此不需要用attn_mask。

def forward(self,
            query,
            key=None,
            value=None,
            query_pos=None,
            key_pos=None,
            attn_masks=None,
            query_key_padding_mask=None,
            key_padding_mask=None,
            **kwargs):
    """Forward function for `TransformerDecoderLayer`.

    **kwargs contains some specific arguments of attentions.

    Args:
        query (Tensor): The input query with shape
            [num_queries, bs, embed_dims] if
            self.batch_first is False, else
            [bs, num_queries embed_dims].
        key (Tensor): The key tensor with shape [num_keys, bs,
            embed_dims] if self.batch_first is False, else
            [bs, num_keys, embed_dims] .
        value (Tensor): The value tensor with same shape as `key`.
        query_pos (Tensor): The positional encoding for `query`.
            Default: None.
        key_pos (Tensor): The positional encoding for `key`.
            Default: None.
        attn_masks (List[Tensor] | None): 2D Tensor used in
            calculation of corresponding attention. The length of
            it should equal to the number of `attention` in
            `operation_order`. Default: None.
        query_key_padding_mask (Tensor): ByteTensor for `query`, with
            shape [bs, num_queries]. Only used in `self_attn` layer.
            Defaults to None.
        key_padding_mask (Tensor): ByteTensor for `query`, with
            shape [bs, num_keys]. Default: None.

    Returns:
        Tensor: forwarded results with shape [num_queries, bs, embed_dims].
    """
    norm_index = 0
    attn_index = 0
    ffn_index = 0
    identity = query
    if attn_masks is None:
        attn_masks = [None for _ in range(self.num_attn)]
    elif isinstance(attn_masks, torch.Tensor):
        attn_masks = [
            copy.deepcopy(attn_masks) for _ in range(self.num_attn)
        ]
        warnings.warn(f'Use same attn_mask in all attentions in '
                      f'{
      
      self.__class__.__name__} ')
    else:
        assert len(attn_masks) == self.num_attn, f'The length of ' \
                    f'attn_masks {
      
      len(attn_masks)} must be equal ' \
                    f'to the number of attention in ' \
                    f'operation_order {
      
      self.num_attn}'

    for layer in self.operation_order:                  # 遍历config文件的顺序
        if layer == 'self_attn':
            temp_key = temp_value = query 
            query = self.attentions[attn_index](        # 内部调用nn.MultiHeadAttn
                query,
                temp_key,
                temp_value,
                identity if self.pre_norm else None,
                query_pos=query_pos,                    # 若有位置编码信息则和query相加 
                key_pos=query_pos,                       # 若有位置编码信息则和key相加 
                attn_mask=attn_masks[attn_index],
                key_padding_mask=query_key_padding_mask,
                **kwargs)
            attn_index += 1
            identity = query

        elif layer == 'norm':
            query = self.norms[norm_index](query)      # 层归一化
            norm_index += 1

        elif layer == 'cross_attn':                    # decoder用到
            query = self.attentions[attn_index](     
                query,
                key,
                value,
                identity if self.pre_norm else None,
                query_pos=query_pos,                   # 若有位置编码信息则和query相加 
                key_pos=key_pos,                        # 若有位置编码信息则和key相加 
                attn_mask=attn_masks[attn_index],
                key_padding_mask=key_padding_mask,
                **kwargs)
            attn_index += 1
            identity = query

        elif layer == 'ffn':                         # 残差连接加全连接层
            query = self.ffns[ffn_index](
                query, identity if self.pre_norm else None)
            ffn_index += 1

    return query

 decoder部分和encoder流程类似,只是多了交叉注意力。

decoder=dict(
    type='DetrTransformerDecoder',
    return_intermediate=True,
    num_layers=6,
    transformerlayers=dict(
        type='DetrTransformerDecoderLayer',
        attn_cfgs=dict(
            type='MultiheadAttention',
            embed_dims=256,
            num_heads=8,
            dropout=0.1),
        feedforward_channels=2048,
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'cross_attn', 'norm','ffn', 'norm')),
))

 我这里简单贴下nn.MultiHeadAttn内部流程:

attn_output_weights = torch.bmm(q, k.transpose(1, 2))  # 计算Q*K
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] # 判断一个tensor的shape是否等于某个尺寸,将其转成list。
# 利用attn_mask将未来的词遮挡住
if attn_mask is not None:
    if attn_mask.dtype == torch.bool:
        attn_output_weights.masked_fill_(attn_mask, float("-inf"))
    else:
        attn_output_weights += attn_mask
# 借助key_padding_mask将pad部分遮挡住
if key_padding_mask is not None:
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) # [2,8,5,5]
    attn_output_weights = attn_output_weights.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2),
        float("-inf"),
    )
    attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

  上述代码流程比较简单,就是首先计算Q中每个元素和K的相似度,要依次用两种mask来遮挡住,为后续的softmax做准备。以cross attn为例,attn_output_weights是计算了每个真实单词和原始句子每个单词的相似性权重,所以要用和src_key_padding_mask一样的memory_key_padding_mask在行的维度上进行遮挡,故二者pad_mask是一致的

总结

  由于后续在detr上改进的论文对匈牙利算法以及loss计算改动不大,因此这部分代码就不讲解了。 感觉写的已经够乱了,哭脸。若有问题欢迎+vx:wulele2541612007,拉你进群探讨交流。

猜你喜欢

转载自blog.csdn.net/wulele2/article/details/123496514
今日推荐