Interpretación del código fuente de detr de mmdetection

prefacio

 Este artículo primero presenta brevemente el principio del documento Detr y luego presenta la implementación de detr en mmdetection.
 Dirección en papel: chrome-extension://ibllepbpahcoppkjjllbabhnigcbffpi/https://arxiv.53yu.com/pdf/2005.12872.pdf,

1. Una breve introducción al principio

inserte la descripción de la imagen aquí
 Proceso general: después de proporcionar una imagen de entrada, 1) Extracción del vector de características: primero, el mapa de características de la última capa F de la imagen se extrae a través de ResNet.Tenga en cuenta que aquí solo se utiliza una capa de mapas de características debido a la posterior complejidad computacional Además, dado que solo se usa la última capa de mapas de características, no es amigable para la detección de objetivos pequeños, que también es la razón de la mejora posterior de detr. deformable2) Agregue información de codificación de posición: aplane F en un tensor unidimensional y agregue información de codificación de posición para obtener I. 3) La parte del codificador del Transformador 4) La parte del decodificador del Transformador, aprendiendo la posición de incrustación de consultas de objetos. 5) Parte FFN: 6) Calculo húngaro subsiguiente de emparejamiento + pérdida.

2. Introducción del código fuente en mmdetection

2.1 Lógica general

  La lógica interna de Detr es la siguiente: en mmdet/models/detector/single_stage.py. Es decir, primero se extrae el vector de características de la imagen y luego se calcula la pérdida final a través de DetrHead.

def forward_train(self,
                  img,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None):
    super(SingleStageDetector, self).forward_train(img, img_metas)
    x = self.extract_feat(img) # 提取图像特征向量  
    # 经过DetrHead得到loss                   
    losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                          gt_labels, gt_bboxes_ignore)
    return losses

2.2 Extracción de vectores de características de imagen

 El archivo de configuración de configuración para extraer vectores de características de imagen en mmdet es el siguiente: se puede encontrar que se usa ResNet50 y solo se extrae la última capa de capas de características, es decir, out_indices=(3,). Para conocer el principio interno, consulte la publicación de mi blog: Introducción a la columna vertebral de mmdet.

backbone=dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(3, ),     # detr仅要resnet50的最后一层特征图,并不需要FPN
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=False),
    norm_eval=True,
    style='pytorch',
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))

2.3 Agregar información de codificación de posición al vector de características de la imagen

 Esta parte del código proviene de mmdet/models/dense_heads/detr_head.py.
  La generación de información de codificación de posición en mmdet se basa en la matriz de máscara, la llamadaLa máscara es para rellenar la imagen con el fin de unificar el tamaño del lote, la parte llena debe descartarse en el cálculo posterior de atención multicabezal, por lo que se necesita una matriz de máscara para cubrirla., la forma específica es [batch, h, w] Aquí primero pegue el proceso de generación de la máscara:

batch_size = x.size(0)   
input_img_h, input_img_w = img_metas[0]['batch_input_shape']# 一个批次图像大小
masks = x.new_ones((batch_size, input_img_h, input_img_w))  # [b,838,768]
for img_id in range(batch_size):
    img_h, img_w, _ = img_metas[img_id]['img_shape']    # 创建了一个mask,非0代表无效区域, 0 代表有效区域
    masks[img_id, :img_h, :img_w] = 0                   # 将pad部分置为1,非pad部分置为0.

 Simplemente pego el diagrama esquemático de la máscara aquí:
inserte la descripción de la imagen aquí
  sobre la base de la máscara [lote, 256, h, w], tenga en cuenta que el hw en este momento es el tamaño de la imagen original; y el hw de la imagen de entrada tiene cambió después de la reducción de resolución de resnet50, por lo que es necesario reducir aún más la resolución de la máscara en la misma forma que el vector de características de la imagen. el código se muestra a continuación:

# interpolate masks to have the same spatial shape with x
masks = F.interpolate(
    masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1) # masks和x的shape一样:[b,27,24]

 Posteriormente, se puede generar la parte de codificación de posición (mmdet/models/utils/position_encoding.py) Esta función genera un vector de posición único de 256 dimensiones para cada posición de píxel de las máscaras. Simplemente escribí un script de prueba:

import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils.positional_encoding import POSITIONAL_ENCODING  # 加载注册器

positional_encoding = dict(type='SinePositionalEncoding', num_feats=128, normalize=True)
self = build_positional_encoding(positional_encoding)
self.eval()
mask = torch.tensor([[[0,0,1],[0,0,1],[1,1,1]]], dtype= torch.uint8) # [1,3,3]
out = self(mask)          # [b,256,h,w]

 Si está interesado, puede echar un vistazo a la lógica de implementación de la codificación de posición de mmdet (solo un simple comentario):

def forward(self, mask):
    """Forward function for `SinePositionalEncoding`.

    Args:
        mask (Tensor): ByteTensor mask. Non-zero values representing
            ignored positions, while zero values means valid positions
            for this image. Shape [bs, h, w].

    Returns:
        pos (Tensor): Returned position embedding with shape
            [bs, num_feats*2, h, w].
    """
    # For convenience of exporting to ONNX, it's required to convert
    # `masks` from bool to int.
    mask = mask.to(torch.int)
    not_mask = 1 - mask       # 取反将1的位置视为图像区域
    y_embed = not_mask.cumsum(1, dtype=torch.float32)  # 累加1得到y方向坐标 [h]
    x_embed = not_mask.cumsum(2, dtype=torch.float32)  # 累加1得到x方向坐标 [w]
    # 归一化过程就是除以坐标中的max,而最后一行/列就是累加的最大的向量
    if self.normalize:
        y_embed = (y_embed + self.offset) / \
                  (y_embed[:, -1:, :] + self.eps) * self.scale # 取最后一行
        x_embed = (x_embed + self.offset) / \
                  (x_embed[:, :, -1:] + self.eps) * self.scale # 取最后一列
    # 创建一个[128]的特征向量
    dim_t = torch.arange(
        self.num_feats, dtype=torch.float32, device=mask.device)
    dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)  # 归一化的 [0,0,1,1,2,2,...,64,64]乘温度系数
    # 行列坐标分别除以dim_t得到每个点的128维的行列特征向量
    pos_x = x_embed[:, :, :, None] / dim_t                         # [b,h,w,128]
    pos_y = y_embed[:, :, :, None] / dim_t                         # [b,h,w,128]
    # use `view` instead of `flatten` for dynamically exporting to ONNX
    B, H, W = mask.size()
    # 分别采样奇数和偶数位置并执行sin和cos,并拼接[b,h,w,128]
    pos_x = torch.stack(
        (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
        dim=4).view(B, H, W, -1)
    pos_y = torch.stack(
        (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
        dim=4).view(B, H, W, -1)
    # 最后将横纵坐标拼接得到每个点唯一的256维度的位置向量
    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)     # [b,256,h,w]
    return pos

3. Enviar a Transformador

3.1 Lógica general

  Después de obtener el vector de características de la imagen x=[b,c,h,w], la matriz de máscaras[b,h,w] y el código de posición pos_embed[b,256,h,w], se puede enviar a Transformer, la clave es para aclarar a que se refieren los QKV de codificador y decodificador, ver el código:

memory = self.encoder(
    query=x,                     # [hw,b,c]
    key=None,
    value=None,
    query_pos=pos_embed,         # [hw,b,c]
    query_key_padding_mask=mask) # [b,hw]
target = torch.zeros_like(query_embed)   # decoder初始化全0
# out_dec: [num_layers, num_query, bs, dim]
out_dec = self.decoder(
    query=target,                # 全0的target, 后续在MultiHeadAttn中执行了
    key=memory,                  # query = query + query_pos又加回去了。
    value=memory,
    key_pos=pos_embed,
    query_pos=query_embed,       # [num_query, bs, dim]
    key_padding_mask=mask)
out_dec = out_dec.transpose(1, 2)

  Entre ellos, q en el codificador es x, kv es Ninguno, query_pos representa la codificación de posición y query_key_padding_mask es una máscara. descifradorq es el objetivo de todos los 0, y el decodificador subsiguiente actualizará iterativamente q, y kv es la memoria, es decir, la salida del codificador; key_pos sigue siendo la información de posición de k; query_embed es la consulta del objeto en el documento, que puede aprender la información de posición; key_padding_mask sigue siendo la máscara.

3.2 Parte del codificador

  Veamos primero la parte de inicialización del codificador. El bucle interno llama a BaseTransformerLayer 6 veces, por lo que solo necesitamos explicar una capa de EncoderLayer.

encoder=dict(
    type='DetrTransformerEncoder',
    num_layers=6,                        # 经过6层Layer
    transformerlayers=dict(              # 每层layer内部使用多头注意力
        type='BaseTransformerLayer',
        attn_cfgs=[
            dict(
                type='MultiheadAttention',
                embed_dims=256,           
                num_heads=8,
                dropout=0.1)
        ],
        feedforward_channels=2048,        # FFN中间层的维度   
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'ffn', 'norm'))), # 定义运算流程

 Veamos la parte delantera de BaseTransformerLayer,Esta parte puede perder la parte central de detr, porque esencialmente mmdet simplemente encapsula la función nn.MultiHeadAtten existente de pytorch. Por lo tanto, es necesario comprender el significado de los dos parámetros de máscara en nn.MultiHeadAttn, debido a limitaciones de espacio, aquí puede consultar nn.Transformer para comprender estas dos máscaras.Pero el entendimiento simple es: attn_mask no se usa en detr, solo se usa key_padding_mask. attn_mask se usa para bloquear información de texto futura, y la imagen puede ver toda la información, por lo que no se necesita attn_mask.

def forward(self,
            query,
            key=None,
            value=None,
            query_pos=None,
            key_pos=None,
            attn_masks=None,
            query_key_padding_mask=None,
            key_padding_mask=None,
            **kwargs):
    """Forward function for `TransformerDecoderLayer`.

    **kwargs contains some specific arguments of attentions.

    Args:
        query (Tensor): The input query with shape
            [num_queries, bs, embed_dims] if
            self.batch_first is False, else
            [bs, num_queries embed_dims].
        key (Tensor): The key tensor with shape [num_keys, bs,
            embed_dims] if self.batch_first is False, else
            [bs, num_keys, embed_dims] .
        value (Tensor): The value tensor with same shape as `key`.
        query_pos (Tensor): The positional encoding for `query`.
            Default: None.
        key_pos (Tensor): The positional encoding for `key`.
            Default: None.
        attn_masks (List[Tensor] | None): 2D Tensor used in
            calculation of corresponding attention. The length of
            it should equal to the number of `attention` in
            `operation_order`. Default: None.
        query_key_padding_mask (Tensor): ByteTensor for `query`, with
            shape [bs, num_queries]. Only used in `self_attn` layer.
            Defaults to None.
        key_padding_mask (Tensor): ByteTensor for `query`, with
            shape [bs, num_keys]. Default: None.

    Returns:
        Tensor: forwarded results with shape [num_queries, bs, embed_dims].
    """
    norm_index = 0
    attn_index = 0
    ffn_index = 0
    identity = query
    if attn_masks is None:
        attn_masks = [None for _ in range(self.num_attn)]
    elif isinstance(attn_masks, torch.Tensor):
        attn_masks = [
            copy.deepcopy(attn_masks) for _ in range(self.num_attn)
        ]
        warnings.warn(f'Use same attn_mask in all attentions in '
                      f'{
      
      self.__class__.__name__} ')
    else:
        assert len(attn_masks) == self.num_attn, f'The length of ' \
                    f'attn_masks {
      
      len(attn_masks)} must be equal ' \
                    f'to the number of attention in ' \
                    f'operation_order {
      
      self.num_attn}'

    for layer in self.operation_order:                  # 遍历config文件的顺序
        if layer == 'self_attn':
            temp_key = temp_value = query 
            query = self.attentions[attn_index](        # 内部调用nn.MultiHeadAttn
                query,
                temp_key,
                temp_value,
                identity if self.pre_norm else None,
                query_pos=query_pos,                    # 若有位置编码信息则和query相加 
                key_pos=query_pos,                       # 若有位置编码信息则和key相加 
                attn_mask=attn_masks[attn_index],
                key_padding_mask=query_key_padding_mask,
                **kwargs)
            attn_index += 1
            identity = query

        elif layer == 'norm':
            query = self.norms[norm_index](query)      # 层归一化
            norm_index += 1

        elif layer == 'cross_attn':                    # decoder用到
            query = self.attentions[attn_index](     
                query,
                key,
                value,
                identity if self.pre_norm else None,
                query_pos=query_pos,                   # 若有位置编码信息则和query相加 
                key_pos=key_pos,                        # 若有位置编码信息则和key相加 
                attn_mask=attn_masks[attn_index],
                key_padding_mask=key_padding_mask,
                **kwargs)
            attn_index += 1
            identity = query

        elif layer == 'ffn':                         # 残差连接加全连接层
            query = self.ffns[ffn_index](
                query, identity if self.pre_norm else None)
            ffn_index += 1

    return query

 La parte del decodificador es similar al proceso del codificador, pero con más atención cruzada.

decoder=dict(
    type='DetrTransformerDecoder',
    return_intermediate=True,
    num_layers=6,
    transformerlayers=dict(
        type='DetrTransformerDecoderLayer',
        attn_cfgs=dict(
            type='MultiheadAttention',
            embed_dims=256,
            num_heads=8,
            dropout=0.1),
        feedforward_channels=2048,
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'cross_attn', 'norm','ffn', 'norm')),
))

adjunto

 Aquí simplemente pego el proceso interno de nn.MultiHeadAttn:

attn_output_weights = torch.bmm(q, k.transpose(1, 2))  # 计算Q*K
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] # 判断一个tensor的shape是否等于某个尺寸,将其转成list。
# 利用attn_mask将未来的词遮挡住
if attn_mask is not None:
    if attn_mask.dtype == torch.bool:
        attn_output_weights.masked_fill_(attn_mask, float("-inf"))
    else:
        attn_output_weights += attn_mask
# 借助key_padding_mask将pad部分遮挡住
if key_padding_mask is not None:
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) # [2,8,5,5]
    attn_output_weights = attn_output_weights.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2),
        float("-inf"),
    )
    attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

  El flujo de código anterior es relativamente simple, es decir, primero calcule la similitud entre cada elemento en Q y K, y luego use dos máscaras para cubrirlos con el fin de prepararse para el subsiguiente softmax. Tome cross attn como ejemplo,attn_output_weights calcula el peso de similitud entre cada palabra real y cada palabra en la oración original, por lo tanto, use la misma memory_key_padding_mask que src_key_padding_mask para bloquear en la dimensión de la fila, de modo que pad_masks de los dos sean consistentes

Resumir

  Dado que los documentos mejorados posteriores sobre detr tienen pocos cambios en el algoritmo húngaro y el cálculo de pérdidas, esta parte del código no se explicará. Siento que la escritura es lo suficientemente desordenada, cara de llanto. Si tiene alguna pregunta, dé la bienvenida a +vx: wulele2541612007 y llévelo al grupo para discutir e intercambiar.

Supongo que te gusta

Origin blog.csdn.net/wulele2/article/details/123496514
Recomendado
Clasificación