Interpreting the DETR source code in mmdetection

Foreword

 This article first briefly introduces the principle of the DETR paper and then walks through the DETR implementation in mmdetection.
 Paper address: https://arxiv.53yu.com/pdf/2005.12872.pdf

1. A brief introduction to the principle

 Overall process: given an input image, DETR proceeds as follows.
1) Feature extraction: ResNet extracts the last-layer feature map F of the image. Only this single feature map is used, to keep the cost of the subsequent attention manageable. Because only the last (coarsest) feature map is used, DETR is weak on small objects, which is one of the motivations for the later Deformable DETR.
2) Positional encoding: F (after a 1×1 convolution reduces its channels, 2048 → 256 in the paper) is flattened along its spatial dimensions into a sequence of h×w feature vectors, and positional encoding is added to obtain I.
3) The Transformer encoder.
4) The Transformer decoder, whose inputs are learned positional embeddings called object queries (100 of them in the paper).
5) FFN: a shared feed-forward network predicts a class (including a "no object" class) and a box for each decoder output.
6) Hungarian matching between predictions and ground truth, followed by the loss computation.
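
To make the data flow above concrete, here is a minimal shape-level sketch in plain PyTorch. It is not mmdet's code: it uses torch.nn.Transformer, random tensors in place of the ResNet features and the sine positional encoding, toy sizes (80 classes, an 8×10 feature map), and feeds the object queries directly as the decoder input, whereas DETR starts the decoder from zeros and adds the queries and positional encodings inside every attention layer.

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions for illustration); the paper uses a ResNet-50 C5 map
# with 2048 channels, d_model = 256 and 100 object queries.
batch, c_backbone, h, w = 2, 2048, 8, 10
d_model, num_queries, num_classes = 256, 100, 80

# 1) Stand-in for the last ResNet feature map F (random values here).
feat = torch.randn(batch, c_backbone, h, w)
input_proj = nn.Conv2d(c_backbone, d_model, kernel_size=1)  # 2048 -> 256 channels

# 2) Flatten the spatial grid into h*w tokens: [h*w, batch, d_model].
src = input_proj(feat).flatten(2).permute(2, 0, 1)
pos = torch.randn(h * w, batch, d_model)  # placeholder for the sine positional encoding

# 3) + 4) Encoder and decoder; the decoder input here is the learned object queries.
transformer = nn.Transformer(d_model=d_model, nhead=8,
                             num_encoder_layers=6, num_decoder_layers=6,
                             dim_feedforward=2048, dropout=0.1)
query_embed = nn.Embedding(num_queries, d_model)
queries = query_embed.weight.unsqueeze(1).repeat(1, batch, 1)  # [num_queries, batch, d_model]
# Simplification: positions are added once to the encoder input only.
hs = transformer(src + pos, queries)  # [num_queries, batch, d_model]

# 5) FFN heads: class logits (+1 for "no object") and normalized (cx, cy, w, h) boxes.
class_head = nn.Linear(d_model, num_classes + 1)
bbox_head = nn.Sequential(
    nn.Linear(d_model, d_model), nn.ReLU(),
    nn.Linear(d_model, d_model), nn.ReLU(),
    nn.Linear(d_model, 4),
)
logits = class_head(hs)          # [num_queries, batch, num_classes + 1]
boxes = bbox_head(hs).sigmoid()  # [num_queries, batch, 4], values in (0, 1)
```

Step 6 (Hungarian matching and the loss) would then pair each of the 100 predictions per image with at most one ground-truth object.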

2. Source code introduction in mmdetection

2.1. Overall logic

  DETR's top-level training logic is in mmdet/models/detectors/single_stage.py (the DETR detector inherits from SingleStageDetector): the image features are extracted first, and DetrHead then computes the final loss.

def forward_train(self,
                  img,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None):
    super(SingleStageDetector, self).forward_train(img, img_metas)
    x = self.extract_feat(img)  # extract image features
    # DetrHead computes the losses
    losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                          gt_labels, gt_bboxes_ignore)
    return losses

2.2. Image feature vector extraction

 The config for extracting image features in mmdet is shown below. ResNet-50 is used and only the feature map of the last stage is output, i.e. out_indices=(3,). For the internals, see my blog post: Introduction to the backbone of mmdet.

backbone=dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(3, ),     # DETR only uses the last ResNet-50 feature map; no FPN is needed
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=False),
    norm_eval=True,
    style='pytorch',
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
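
A quick way to see how coarse that single feature map is: every stride-2 layer in ResNet-50 maps a side of length n to ceil(n / 2), and there are five of them up to the last stage (the 7×7 stem convolution, the 3×3 max-pool, and the first block of stages 2 to 4). The sketch below assumes that standard layout; resnet_c5_size is a hypothetical helper, and it reproduces the [b, 838, 768] → [b, 27, 24] shapes quoted in the comments later in this article.

```python
def resnet_c5_size(n: int) -> int:
    """Side length of the ResNet-50 C5 map (out_indices=(3,)) for an input side of n.

    Assumes each of the five stride-2 layers (stem conv with padding 3, max-pool with
    padding 1, stride-2 convs of stages 2-4) maps n to ceil(n / 2).
    """
    for _ in range(5):      # five stride-2 downsamplings -> total stride 32
        n = (n + 1) // 2    # integer ceil(n / 2)
    return n

print(resnet_c5_size(838), resnet_c5_size(768))  # 27 24
```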

2.3. Add position encoding information to the image feature vector

 This part of the code comes from mmdet/models/dense_heads/detr_head.py.
  The positional encoding in mmdet is generated from a mask matrix. Images in a batch are padded to a common size so that they can be stacked into one tensor; the padded area must be ignored when computing multi-head attention, so a mask of shape [batch, h, w] marks it. Here is how the mask is generated:

batch_size = x.size(0)   
input_img_h, input_img_w = img_metas[0]['batch_input_shape']  # padded size of the whole batch
masks = x.new_ones((batch_size, input_img_h, input_img_w))  # [b, 838, 768]
for img_id in range(batch_size):
    img_h, img_w, _ = img_metas[img_id]['img_shape']  # size of this image before padding
    masks[img_id, :img_h, :img_w] = 0  # non-zero = padding (invalid), 0 = valid image region
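
The loop above can be tried without mmdet. The sketch below uses a hypothetical helper, build_padding_masks, with made-up image sizes in place of img_metas; as in the mmdet code, 1 marks padding and 0 marks real pixels.

```python
import torch

def build_padding_masks(batch_input_shape, img_shapes):
    """Return a [b, H, W] float mask: 1 = padding, 0 = valid pixels.

    batch_input_shape: (H, W) of the padded batch.
    img_shapes: list of (h, w, c) per image, mimicking img_metas[i]['img_shape'].
    """
    input_h, input_w = batch_input_shape
    masks = torch.ones(len(img_shapes), input_h, input_w)  # start as "all padding"
    for img_id, (img_h, img_w, _) in enumerate(img_shapes):
        masks[img_id, :img_h, :img_w] = 0                  # real image region -> 0
    return masks

# Hypothetical batch: a 4x3 image and a 2x5 image, both padded to 4x5.
masks = build_padding_masks((4, 5), [(4, 3, 3), (2, 5, 3)])
print(masks)
```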

 Here is a schematic diagram of the mask:
  At this point the mask has shape [batch, h, w], where h and w are the padded input image size. After ResNet-50 downsampling (stride 32) the feature map is much smaller, so the mask must be downsampled to the same spatial shape as the image features. The code is as follows:

# interpolate masks to have the same spatial shape with x
masks = F.interpolate(
    masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)  # masks now matches x spatially: [b, 27, 24]
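
F.interpolate with its default 'nearest' mode picks one input pixel per output cell, so each feature-map cell inherits the padding flag of the pixel it samples. A small sketch with a hypothetical 8×8 mask (the real image occupies the top-left 6×4 block) resized to 2×2:

```python
import torch
import torch.nn.functional as F

# Hypothetical 8x8 padded input; the real image is the top-left 6 rows x 4 columns.
mask = torch.ones(1, 8, 8)
mask[0, :6, :4] = 0

# Same call pattern as above: add a channel dim, resize with the default
# 'nearest' mode, then cast to bool (True = padding).
feat_hw = (2, 2)  # toy feature-map size (stride 4 here)
small = F.interpolate(mask.unsqueeze(1), size=feat_hw).to(torch.bool).squeeze(1)
print(small)
```

The bottom row stays valid because its sampled pixel (row 4) lies inside the 6-row image, even though rows 6 and 7 of that cell are padding; near the border the downsampled mask is only approximate.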

 Next, the positional encoding is generated (mmdet/models/utils/positional_encoding.py). It produces a unique 256-dimensional position vector for every pixel position of masks. Here is a simple test script:

import torch
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils.positional_encoding import POSITIONAL_ENCODING  # importing registers SinePositionalEncoding

positional_encoding = dict(type='SinePositionalEncoding', num_feats=128, normalize=True)
pos_enc = build_positional_encoding(positional_encoding)
pos_enc.eval()
mask = torch.tensor([[[0, 0, 1], [0, 0, 1], [1, 1, 1]]], dtype=torch.uint8)  # [1, 3, 3]
out = pos_enc(mask)  # [b, 256, h, w], here [1, 256, 3, 3]

 If you are interested, here is mmdet's positional encoding implementation with brief comments:

def forward(self, mask):
    """Forward function for `SinePositionalEncoding`.

    Args:
        mask (Tensor): ByteTensor mask. Non-zero values representing
            ignored positions, while zero values means valid positions
            for this image. Shape [bs, h, w].

    Returns:
        pos (Tensor): Returned position embedding with shape
            [bs, num_feats*2, h, w].
    """
    # For convenience of exporting to ONNX, it's required to convert
    # `masks` from bool to int.
    mask = mask.to(torch.int)
    not_mask = 1 - mask       # invert: 1 now marks valid image pixels
    y_embed = not_mask.cumsum(1, dtype=torch.float32)  # cumulative sum along h gives the y coordinate
    x_embed = not_mask.cumsum(2, dtype=torch.float32)  # cumulative sum along w gives the x coordinate
    # Normalization divides by the largest coordinate, i.e. the last row / column of the cumsum
    if self.normalize:
        y_embed = (y_embed + self.offset) / \
                  (y_embed[:, -1:, :] + self.eps) * self.scale # last row
        x_embed = (x_embed + self.offset) / \
                  (x_embed[:, :, -1:] + self.eps) * self.scale # last column
    # Build a [128] frequency vector
    dim_t = torch.arange(
        self.num_feats, dtype=torch.float32, device=mask.device)
    dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)  # temperature ** (2 * [0,0,1,1,...,63,63] / 128)
    # Dividing the coordinates by dim_t gives a 128-dim x feature and y feature per point
    pos_x = x_embed[:, :, :, None] / dim_t                         # [b,h,w,128]
    pos_y = y_embed[:, :, :, None] / dim_t                         # [b,h,w,128]
    # use `view` instead of `flatten` for dynamically exporting to ONNX
    B, H, W = mask.size()
    # sin on even channels, cos on odd channels, interleaved back to [b,h,w,128]
    pos_x = torch.stack(
        (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
        dim=4).view(B, H, W, -1)
    pos_y = torch.stack(
        (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
        dim=4).view(B, H, W, -1)
    # Concatenate y and x to get a unique 256-dim position vector for every point
    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)     # [b,256,h,w]
    return pos
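
To experiment without installing mmcv or mmdet, the forward above can be restated as a plain function. This is a sketch: sine_positional_encoding is a hypothetical name, and the defaults temperature=10000, scale=2π, eps=1e-6 and offset=0 are my assumption of mmdet's defaults, so check them against your version.

```python
import math
import torch

def sine_positional_encoding(mask, num_feats=128, temperature=10000,
                             normalize=True, scale=2 * math.pi, eps=1e-6, offset=0.0):
    """Plain-PyTorch restatement of the forward() above.

    mask: [bs, h, w], non-zero = padding. Returns [bs, 2 * num_feats, h, w].
    """
    not_mask = 1 - mask.to(torch.int)                   # 1 = valid pixel
    y_embed = not_mask.cumsum(1, dtype=torch.float32)   # 1-based row index
    x_embed = not_mask.cumsum(2, dtype=torch.float32)   # 1-based column index
    if normalize:                                       # scale coordinates into (0, 2*pi]
        y_embed = (y_embed + offset) / (y_embed[:, -1:, :] + eps) * scale
        x_embed = (x_embed + offset) / (x_embed[:, :, -1:] + eps) * scale
    dim_t = torch.arange(num_feats, dtype=torch.float32, device=mask.device)
    dim_t = temperature ** (2 * (dim_t // 2) / num_feats)  # paired frequencies
    pos_x = x_embed[..., None] / dim_t                  # [bs, h, w, num_feats]
    pos_y = y_embed[..., None] / dim_t
    B, H, W = mask.shape
    # sin on even channels, cos on odd channels, interleaved
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).view(B, H, W, -1)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).view(B, H, W, -1)
    return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

mask = torch.tensor([[[0, 0, 1], [0, 0, 1], [1, 1, 1]]], dtype=torch.uint8)
pos = sine_positional_encoding(mask)
print(pos.shape)  # torch.Size([1, 256, 3, 3])
```

Channels 0 to 127 encode y and channels 128 to 255 encode x; at the top-left pixel both normalized coordinates are about π, so channel 1 (cos of y at the lowest frequency) is close to -1.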

3. Feeding the Transformer

3.1. Overall logic

  After obtaining the image feature x [b,c,h,w], the mask [b,h,w] and the positional encoding pos_embed [b,256,h,w], they can be fed into the Transformer. The key is to be clear about what Q, K and V refer to in the encoder and in the decoder; see the code:

memory = self.encoder(
    query=x,                     # [hw,b,c]
    key=None,
    value=None,
    query_pos=pos_embed,         # [hw,b,c]
    query_key_padding_mask=mask) # [b,hw]
target = torch.zeros_like(query_embed)   # the decoder input is initialized to all zeros
# out_dec: [num_layers, num_query, bs, dim]
out_dec = self.decoder(
    query=target,                # all-zero target; inside MultiheadAttention,
    key=memory,                  # query = query + query_pos adds the object queries back
    value=memory,
    key_pos=pos_embed,
    query_pos=query_embed,       # [num_query, bs, dim]
    key_padding_mask=mask)
out_dec = out_dec.transpose(1, 2)

  In the encoder, the query is x; key and value are None because self-attention uses the query itself as key and value (see the forward code below); query_pos is the positional encoding; query_key_padding_mask is the mask. In the decoder, the query is the all-zero target, which the decoder layers update iteratively; key and value are memory, i.e. the output of the encoder; key_pos is again the positional encoding of the keys; query_pos is query_embed, the learnable object queries from the paper, which act as learned positional information; key_padding_mask is again the mask.
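
The shape comments above ([hw,b,c], [b,hw], [num_query,bs,dim]) imply that the 4-D tensors are flattened before the encoder call. The sketch below shows one way to do that reshaping with toy tensors; it is an illustration under that assumption, not a copy of mmdet's Transformer.forward.

```python
import torch
import torch.nn as nn

bs, c, h, w, num_query = 2, 256, 27, 24, 100  # toy sizes matching the article's comments

x = torch.randn(bs, c, h, w)                     # projected image features
pos_embed = torch.randn(bs, c, h, w)             # positional encoding
mask = torch.zeros(bs, h, w, dtype=torch.bool)   # True = padding
query_embed = nn.Embedding(num_query, c).weight  # learnable object queries [num_query, c]

# [bs, c, h, w] -> [h*w, bs, c]: sequence-first layout used by nn.MultiheadAttention
x_seq = x.flatten(2).permute(2, 0, 1)
pos_seq = pos_embed.flatten(2).permute(2, 0, 1)
# [bs, h, w] -> [bs, h*w]: one padding flag per key token
key_padding_mask = mask.flatten(1)
# [num_query, c] -> [num_query, bs, c]: the same queries for every image
query_pos = query_embed.unsqueeze(1).repeat(1, bs, 1)
target = torch.zeros_like(query_pos)             # decoder input starts at zero
```

The flattening is reversible: x_seq.permute(1, 2, 0).reshape(bs, c, h, w) recovers x, which is how an encoder output can be viewed as a feature map again.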

3.2. Encoder part

  Let's first look at the encoder config. The encoder stacks 6 BaseTransformerLayers and calls them in a loop, so it is enough to explain a single layer.

encoder=dict(
    type='DetrTransformerEncoder',
    num_layers=6,                        # 6 stacked layers
    transformerlayers=dict(              # each layer uses multi-head attention
        type='BaseTransformerLayer',
        attn_cfgs=[
            dict(
                type='MultiheadAttention',
                embed_dims=256,
                num_heads=8,
                dropout=0.1)
        ],
        feedforward_channels=2048,        # hidden dimension of the FFN
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'ffn', 'norm'))), # order of operations

 Next, look at the forward of BaseTransformerLayer. This is arguably the core of DETR, although mmdet essentially just wraps PyTorch's existing nn.MultiheadAttention. It is therefore necessary to understand the two mask parameters of nn.MultiheadAttention; for space reasons, refer to the nn.Transformer documentation for the details. In short: DETR uses only key_padding_mask and never attn_mask. In language models, attn_mask prevents tokens from attending to future tokens, whereas every position of an image may attend to every other position, so attn_mask is not needed.

def forward(self,
            query,
            key=None,
            value=None,
            query_pos=None,
            key_pos=None,
            attn_masks=None,
            query_key_padding_mask=None,
            key_padding_mask=None,
            **kwargs):
    """Forward function for `TransformerDecoderLayer`.

    **kwargs contains some specific arguments of attentions.

    Args:
        query (Tensor): The input query with shape
            [num_queries, bs, embed_dims] if
            self.batch_first is False, else
            [bs, num_queries, embed_dims].
        key (Tensor): The key tensor with shape [num_keys, bs,
            embed_dims] if self.batch_first is False, else
            [bs, num_keys, embed_dims] .
        value (Tensor): The value tensor with same shape as `key`.
        query_pos (Tensor): The positional encoding for `query`.
            Default: None.
        key_pos (Tensor): The positional encoding for `key`.
            Default: None.
        attn_masks (List[Tensor] | None): 2D Tensor used in
            calculation of corresponding attention. The length of
            it should equal to the number of `attention` in
            `operation_order`. Default: None.
        query_key_padding_mask (Tensor): ByteTensor for `query`, with
            shape [bs, num_queries]. Only used in `self_attn` layer.
            Defaults to None.
        key_padding_mask (Tensor): ByteTensor for `key`, with
            shape [bs, num_keys]. Default: None.

    Returns:
        Tensor: forwarded results with shape [num_queries, bs, embed_dims].
    """
    norm_index = 0
    attn_index = 0
    ffn_index = 0
    identity = query
    if attn_masks is None:
        attn_masks = [None for _ in range(self.num_attn)]
    elif isinstance(attn_masks, torch.Tensor):
        attn_masks = [
            copy.deepcopy(attn_masks) for _ in range(self.num_attn)
        ]
        warnings.warn(f'Use same attn_mask in all attentions in '
                      f'{self.__class__.__name__} ')
    else:
        assert len(attn_masks) == self.num_attn, f'The length of ' \
                    f'attn_masks {len(attn_masks)} must be equal ' \
                    f'to the number of attention in ' \
                    f'operation_order {self.num_attn}'

    for layer in self.operation_order:                  # iterate in the order given by operation_order in the config
        if layer == 'self_attn':
            temp_key = temp_value = query
            query = self.attentions[attn_index](        # internally calls nn.MultiheadAttention
                query,
                temp_key,
                temp_value,
                identity if self.pre_norm else None,
                query_pos=query_pos,                    # added to the query if given
                key_pos=query_pos,                      # added to the key if given
                attn_mask=attn_masks[attn_index],
                key_padding_mask=query_key_padding_mask,
                **kwargs)
            attn_index += 1
            identity = query

        elif layer == 'norm':
            query = self.norms[norm_index](query)      # layer normalization
            norm_index += 1

        elif layer == 'cross_attn':                    # used by the decoder
            query = self.attentions[attn_index](
                query,
                key,
                value,
                identity if self.pre_norm else None,
                query_pos=query_pos,                   # added to the query if given
                key_pos=key_pos,                       # added to the key if given
                attn_mask=attn_masks[attn_index],
                key_padding_mask=key_padding_mask,
                **kwargs)
            attn_index += 1
            identity = query

        elif layer == 'ffn':                         # fully connected layers with a residual connection
            query = self.ffns[ffn_index](
                query, identity if self.pre_norm else None)
            ffn_index += 1

    return query
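
Specialized to the encoder's operation_order ('self_attn', 'norm', 'ffn', 'norm'), the loop above reduces to a post-norm Transformer layer. The hypothetical ToyDetrEncoderLayer below sketches that on top of nn.MultiheadAttention. It assumes, as I understand mmcv's attention and FFN wrappers, that the residual connection is added inside the attention and FFN steps, and that the positional encoding is added to Q and K but not to V.

```python
import torch
import torch.nn as nn

class ToyDetrEncoderLayer(nn.Module):
    """Hypothetical layer with operation_order ('self_attn', 'norm', 'ffn', 'norm')."""

    def __init__(self, d_model=256, nhead=8, ffn_dim=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_dim), nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(ffn_dim, d_model), nn.Dropout(dropout))

    def forward(self, query, query_pos, query_key_padding_mask):
        # 'self_attn': positions are added to Q and K; V stays the raw features.
        q = k = query + query_pos
        attn_out = self.attn(q, k, value=query,
                             key_padding_mask=query_key_padding_mask)[0]
        query = self.norm1(query + attn_out)          # residual, then 'norm'
        query = self.norm2(query + self.ffn(query))   # 'ffn' with residual, then 'norm'
        return query

hw, bs, d = 6, 2, 256
layer = ToyDetrEncoderLayer().eval()                   # eval() disables dropout
src = torch.randn(hw, bs, d)
pos = torch.randn(hw, bs, d)
pad = torch.tensor([[False] * 6, [False] * 4 + [True] * 2])  # image 2: last 2 tokens are padding
out = layer(src, pos, pad)
print(out.shape)  # torch.Size([6, 2, 256])
```

Because padded keys are masked and the norms and FFN act per token, changing the padded tokens of the second image does not change its valid outputs.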

 The decoder follows the same flow as the encoder, with an additional cross-attention step in each layer.

decoder=dict(
    type='DetrTransformerDecoder',
    return_intermediate=True,
    num_layers=6,
    transformerlayers=dict(
        type='DetrTransformerDecoderLayer',
        attn_cfgs=dict(
            type='MultiheadAttention',
            embed_dims=256,
            num_heads=8,
            dropout=0.1),
        feedforward_channels=2048,
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')),
)
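
The decoder layer inserts a cross_attn step between self-attention and the FFN. The hypothetical ToyDetrDecoderLayer below sketches one such layer and stacks six of them, keeping every layer's output as return_intermediate=True does. Assumptions: post-norm residuals, query_pos added to the queries in both attention steps, key_pos added to the memory keys. As far as I know, mmdet's DetrTransformerDecoder also applies a final LayerNorm to each intermediate output, which is omitted here.

```python
import torch
import torch.nn as nn

class ToyDetrDecoderLayer(nn.Module):
    """Hypothetical layer with operation_order
    ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')."""

    def __init__(self, d_model=256, nhead=8, ffn_dim=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norms = nn.ModuleList(nn.LayerNorm(d_model) for _ in range(3))
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_dim), nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(ffn_dim, d_model), nn.Dropout(dropout))

    def forward(self, query, memory, query_pos, key_pos, key_padding_mask):
        # 'self_attn' among the object queries (no padding mask needed)
        q = k = query + query_pos
        query = self.norms[0](query + self.self_attn(q, k, value=query)[0])
        # 'cross_attn': queries attend to encoder memory; padded image tokens are masked
        attn = self.cross_attn(query + query_pos, memory + key_pos, value=memory,
                               key_padding_mask=key_padding_mask)[0]
        query = self.norms[1](query + attn)
        # 'ffn' with residual, then the last 'norm'
        return self.norms[2](query + self.ffn(query))

num_query, hw, bs, d = 100, 6, 2, 256
layers = nn.ModuleList(ToyDetrDecoderLayer() for _ in range(6)).eval()
memory, key_pos = torch.randn(hw, bs, d), torch.randn(hw, bs, d)
query_pos = torch.randn(num_query, bs, d)              # object queries
pad = torch.zeros(bs, hw, dtype=torch.bool)            # no padding in this toy batch
target = torch.zeros_like(query_pos)                   # all-zero decoder input
intermediate = []
for layer in layers:                                   # keep every layer's output
    target = layer(target, memory, query_pos, key_pos, pad)
    intermediate.append(target)
out_dec = torch.stack(intermediate)                    # [num_layers, num_query, bs, dim]
```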

Appendix

 Here is an excerpt of the internals of nn.MultiheadAttention:

attn_output_weights = torch.bmm(q, k.transpose(1, 2))  # compute Q·K^T
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]  # shape check
# attn_mask hides future tokens (not used by DETR)
if attn_mask is not None:
    if attn_mask.dtype == torch.bool:
        attn_output_weights.masked_fill_(attn_mask, float("-inf"))
    else:
        attn_output_weights += attn_mask
# key_padding_mask hides the padded key positions
if key_padding_mask is not None:
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)  # e.g. [2, 8, 5, 5]
    attn_output_weights = attn_output_weights.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2),
        float("-inf"),
    )
    attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

  The flow is simple: first compute the similarity between each element of Q and each element of K, then apply the two masks (filling masked entries with -inf) so that the subsequent softmax assigns them zero weight. Take cross-attention as an example: attn_output_weights holds the similarity between each target token (query) and each source token (memory). The padded source positions must be masked along the key (src_len) dimension, which is why memory_key_padding_mask is the same as src_key_padding_mask: both describe the padding of the same source sequence.
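
The effect of key_padding_mask can be checked directly with nn.MultiheadAttention, which returns the attention weights averaged over heads. In this toy cross-attention (random tensors, hypothetical sizes), the last two memory tokens of the second sample are marked as padding:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
tgt_len, src_len, bs, d = 3, 5, 2, 16
attn = nn.MultiheadAttention(d, num_heads=4)
query = torch.randn(tgt_len, bs, d)    # e.g. object queries
memory = torch.randn(src_len, bs, d)   # e.g. encoder output
# True = padded key; sample 2 has its last two memory tokens padded
key_padding_mask = torch.tensor([[False] * 5, [False, False, False, True, True]])

_, weights = attn(query, memory, memory, key_padding_mask=key_padding_mask)
# weights: [bs, tgt_len, src_len], averaged over heads
print(weights[1])  # columns 3 and 4 are zero for every query
```

Because the masked scores are set to -inf before the softmax, those columns receive exactly zero weight and the remaining weights in each row still sum to 1.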

Summary

  Since later papers that improve on DETR change little about the Hungarian matching and the loss computation, that part of the code is not covered here. I admit this write-up is a bit messy. If you have any questions, feel free to add me on WeChat (wulele2541612007) and I will add you to the discussion group.


Source: blog.csdn.net/wulele2/article/details/123496514