Análisis del modelo del registro de aprendizaje del código DAB-DETR

DAB-DETR se perfecciona sobre la base de absorber Deformable-DETR, Conditional-DETR, Anchor-DETR, etc. Su principal contribución es inicializar la consulta a la forma de coordenadas de pensamiento x, y, w, h.
Esta publicación de blog analiza principalmente el trabajo realizado por DAB-DETR desde la perspectiva del código.
DAB-DETR mejora principalmente el modelo Decoder. El blogger analiza principalmente el modelo del módulo Decoder.

inserte la descripción de la imagen aquí

inserte la descripción de la imagen aquí

Ajuste del valor de temperatura codificado por posición

El primero es el archivo position_encoding.py, que redefine un PositionEmbeddingSineHWmétodo cuya función es separar los valores de temperatura de ancho y alto de la parte de codificación de posición de alta frecuencia, de modo que el ancho y el alto tengan valores de temperatura diferentes. Este archivo también mejora el método de codificación de posición sincos y el método de codificación de posición aprendible.

class PositionEmbeddingSineHW(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperatureH = temperatureH
        self.temperatureW = temperatureW
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale
    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        # import ipdb; ipdb.set_trace()
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_tx
        dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
        pos_y = y_embed[:, :, :, None] / dim_ty
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # import ipdb; ipdb.set_trace()
        return pos

Arquitectura general del transformador

Primero entendamos la arquitectura general de Transformer:
Primero, echemos un vistazo a los parámetros pasados ​​por forward:
src: la información de características extraída por la columna vertebral, la forma es inicialmente antorcha. Tamaño ([2, 256,19,24] ) y luego se convierte en antorcha. Tamaño ([456, 2, 256])
máscara: Complete la información de la máscara de la imagen, la forma es inicialmente antorcha. Tamaño ([2, 19, 24]) y luego se aplana a antorcha. Tamaño ( [2, 456] )

refpoint_embed: codificación de coordenadas del punto de referencia, es decir, object_query , torch.Size([300, 4]). Se utiliza en el módulo Decoder, que se inicializa en la definición del módulo DAB-DETR: self.refpoint_embed = nn.Embedding(num_queries, query_dim), inicialmente torch.Size([300,4]), después de refpoint_embed = refpoint_embed.unsqueeze( 1 ).repeat(1, bs, 1) se convierte en torch.Size([300, 4]).

pos_embed: información de codificación de posición, la forma es inicialmente torch.Size([2, 256,19,24]) y luego se convierte en torch.Size([456, 2, 256]) El
código de ejecución del proceso anterior es el siguiente:

    # flatten NxCxHxW to HWxNxC
    bs, c, h, w = src.shape  #初始为2,256,19,24
    src = src.flatten(2).permute(2, 0, 1)#拉平:
    pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
    refpoint_embed = refpoint_embed.unsqueeze(1).repeat(1, bs, 1)
    mask = mask.flatten(1)  

Luego, los datos se envían al módulo codificador y la memoria de salida es: antorcha. Tamaño ([456, 2, 256])

 memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)

A continuación, inicialice tgt y juzgue su patrón de acuerdo con self.num_patterns, cuyo valor predeterminado es 0 aquí. tgt se inicializa a todo 0, y la forma es: torch.Size([300, 2, 256]), que es similar a DETR, que se usa como entrada inicial del decodificador.

 num_queries = refpoint_embed.shape[0]
if self.num_patterns == 0:
    tgt = torch.zeros(num_queries, bs, self.d_model, device=refpoint_embed.device)
else:
    tgt = self.patterns.weight[:, None, None, :].repeat(1, self.num_queries, bs, 1).flatten(0, 1) # n_q*n_pat, bs, d_model
    refpoint_embed = refpoint_embed.repeat(self.num_patterns, 1, 1) # n_q*n_pat, bs, d_model

Luego envíelo al módulo Decodificador:

 hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                      pos=pos_embed, refpoints_unsigmoid=refpoint_embed)
 return hs, references

Construcción del módulo codificador

El módulo codificador de DAB-DETR no es muy diferente de DETR.

EncoderLayer

src_mask=None
src_key_padding_mask: Complemente la imagen con una forma de [2, 456]
src: Las características extraídas por ResNet se convierten de bidimensionales a unidimensionales, y la forma es torch.Size([456, 2, 256])
pos: Información de codificación de posición, originalmente hay dos tipos, codificación de posición sincos y codificación de posición aprendible.Además, DAB-DETR también propone un método de codificación de posición que puede saltar ancho y alto. La forma es torch.Size([456, 2, 256])
src2obtenida a través de la autoatención, y la forma es torch.Size([456, 2, 256]), seguida de la capa de abandono y la capa de norma. El resultado de salida final es src: torch.Size([456, 2, 256]), que se envía a Decoder.

q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src

Al igual que DETR, with_pos_embedes una adición directa.

def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

Módulo codificador

El codificador consta de 6 capas de codificador.

class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None, d_model=256):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.query_scale = MLP(d_model, d_model, d_model, 2)
        self.norm = norm
    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src
        for layer_id, layer in enumerate(self.layers):
            # rescale the content and pos sim
            pos_scales = self.query_scale(output)
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos*pos_scales)
        if self.norm is not None:
            output = self.norm(output)
        return output

Breve resumen del decodificador

En la parte del decodificador query ancor(cuadros de anclaje), se inicializa a [2, 300, 4] y pasará Anchor Sine Encoding, x, y, w, h se convertirán a 128 dimensiones, 4 es 512 dimensiones, y luego pasará una MLPconversión para 256.
El método de codificación de posición es el siguiente: un total de 4, la dimensión codificada es de 128 dimensiones.

inserte la descripción de la imagen aquí
inserte la descripción de la imagen aquí

Una de sus principales novedades es la siguiente, se agrega el mecanismo de atención de modulación de ancho y alto, la razón de esto es hacer que la atención sea más sensible al ancho y al alto.

inserte la descripción de la imagen aquí
inserte la descripción de la imagen aquí

Implementación del código del módulo decodificador

Primero, proporcione tgtel valor de output, como se puede ver aquí, el resultado de salida es outputy su forma es torch.Size([300, 2, 256])

output = tgt

Se reference_pointsnormalizará, la forma sigue siendo antorcha. Tamaño ([300, 2, 4])

reference_points = refpoints_unsigmoid.sigmoid()

Después de ingresar al bucle del decodificador, primero reference_pointscodifique la posición de alta frecuencia, es decir, saque todos los valores, ingrese el módulo de codificación de posición de alta frecuencia y cambie de antorcha. Tamaño ([300, 2, 4]) a antorcha. Tamaño ([300, 2, 512]), cada uno se convierte en 128, de la siguiente manera:
inserte la descripción de la imagen aquí
Luego, después de un self.ref_point_head(MLP)cambio a antorcha. Tamaño ([300, 2, 256])

obj_center = reference_points[..., :self.query_dim]  #torch.Size([300, 2, 4])  
query_sine_embed = gen_sineembed_for_position(obj_center) #torch.Size([300,2,512])
query_pos = self.ref_point_head(query_sine_embed) #torch.Size([300, 2, 256])

gen_sineembed_for_positionMétodos de la siguiente manera:

def gen_sineembed_for_position(pos_tensor):
    # n_query, bs, _ = pos_tensor.size()
    # sineembed_tensor = torch.zeros(n_query, bs, 256)
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
    return pos

Luego se realiza alguna inicialización self.query_scaley MLPla salida se puede considerar como el resultado de salida de la capa anterior de Decoder.
Saque query_sine_embedlas primeras 256 dimensiones, es decir, pos_transformationmultiplique x, y por (1 en la primera capa).
ref_anchor_heades un MLP con self.ref_anchor_head = MLP(d_model, d_model, 2, 2)una dimensión de entrada de 256, un ancho de capa intermedia de 256, una dimensión de salida de 2 y una capa oculta de 2.
refHW_cond es torch.Size([300, 2, 2])
query_sine_embed es inicialmente torch.Size([300, 2, 512]), después de los siguientes query_sine_embed = query_sine_embed[...,:self.d_model] * pos_transformationcambios a torch.Size([300, 2, 256]), esta oración El código significa tomar las primeras 256 dimensiones

if self.query_scale_type != 'fix_elewise':#执行
      if layer_id == 0:#第一层时执行
          pos_transformation = 1
      else:
          pos_transformation = self.query_scale(output) #query_scale为MLP
else:
     pos_transformation = self.query_scale.weight[layer_id]
#取出  query_sine_embed的前256维,即x,y与pos_transformation相乘 
query_sine_embed = query_sine_embed[...,:self.d_model] * pos_transformation
if self.modulate_hw_attn:
      refHW_cond = self.ref_anchor_head(output).sigmoid() #将其送入MLP后进行归一化 torch.Size([300, 2, 2])
      query_sine_embed[..., self.d_model // 2:] *= (refHW_cond[..., 0] / obj_center[..., 2]).unsqueeze(-1)
      query_sine_embed[..., :self.d_model // 2] *= (refHW_cond[..., 1] / obj_center[..., 3]).unsqueeze(-1)

El código anterior en realidad ejecuta el siguiente proceso: Tenga en cuenta que no es que PE(Xref) y PE(Yref) no se multipliquen en este momento, sino porque está establecido en 1, es decir, aquí podemos verlo en el segundo capa de DecoderLayer pos_transformation = 1.

inserte la descripción de la imagen aquí

inserte la descripción de la imagen aquí

Luego envíe los datos a DecoderLayer, tenga en cuenta que DecoderLayer es la primera capa en este momento.

output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                           is_first=(layer_id == 0))

El módulo DecoderLayer de primera capa

Auto_Atención

Primero, calcule el mecanismo de autoatención en DecoderLayer.
Echemos un vistazo a cómo cambian los datos:
tgtes decir, la salida de la capa anterior de DecoderLayer es todo 0 en este momento, y la forma es torch.Size([300, 2, 256])
primero pasa a través de una capa lineal (sa_qcontent_proj = nn.Linear (d_model, d_model)) para obtener q_content
la forma como torch.Size([300, 2, 256])
Cabe señalar que tgtcuando se completa la inicialización de qkv a través de la capa lineal, aunque tgttodos son 0, q, k, v no son

inserte la descripción de la imagen aquí

Luego q_pos(la información xywh obtenida por Anchor a través de la codificación de posición de alta frecuencia y MLP) también pasa a través de una sa_qpos_projdimensión de capa lineal sin cambios: la forma es torch.Size([300, 2, 256])
y k, v también se inicializan en de la misma manera Igual que DETR, v no tiene información de posición.

Para resumir, en atención propia, el query_pos convertido desde Anchor Box proporciona información de posición, y la información de contenido se proporciona al inicializar a todo 0 o el resultado de salida del DecoderLayer anterior. La información de posición y la información de contenido también se agregan. Combinar juntos, por ejemplo: q = q_content + q_pos
en cuanto a lo siguiente, es exactamente lo mismo que DETR, solo ingrese q, k, v para participar en la operación.

 if not self.rm_self_attn_decoder:
            # Apply projections here
            # shape: num_queries x batch_size x 256
            q_content = self.sa_qcontent_proj(tgt)      # target is the input of the first decoder layer. zero by default.
            q_pos = self.sa_qpos_proj(query_pos)
            k_content = self.sa_kcontent_proj(tgt)
            k_pos = self.sa_kpos_proj(query_pos)
            v = self.sa_v_proj(tgt)
            num_queries, bs, n_model = q_content.shape
            hw, _, _ = k_content.shape
            q = q_content + q_pos
            k = k_content + k_pos
            tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask,
                                key_padding_mask=tgt_key_padding_mask)[0] 
                     #tgt2为Attention计算结果,torch.Size([300, 2, 256])
            # ========== End of Self-Attention =============
            tgt = tgt + self.dropout1(tgt2)
            tgt = self.norm1(tgt)

Finalmente, se obtiene el tgt de salida de autoatención, y la forma es torch.Size([300, 2, 256]).El código anterior ejecuta la parte que se muestra en el cuadro a continuación.

inserte la descripción de la imagen aquí

Entonces está listo para ser introducido cross-attentionen el cálculo.

atención_cruzada

El primero es q k vel proceso de inicialización, se puede observar que q proviene de la salida de self-attention, y luego de una capa lineal, k y v provienen de la salida del Encoder. La dimensión de la memoria es torch.Size([456, 2, 256])

q_content = self.ca_qcontent_proj(tgt)#torch.Size([300, 2, 256])
k_content = self.ca_kcontent_proj(memory)#torch.Size([456, 2, 256])
v = self.ca_v_proj(memory)#torch.Size([456, 2, 256])

k_pos = self.ca_kpos_proj(pos)#对K进行位置编码,pos来自于Encoder。torch.Size([456, 2, 256])

Dado que es la primera capa, se deben realizar las siguientes operaciones, es decir, primero pase query_pos[torch.Size([300, 2, 256])] a través de una capa completamente conectada, y la dimensión no cambia, es decir, q_posel proceso de generar,

if is_first or self.keep_query_pos:#self.keep_query_pos默认为False
    q_pos = self.ca_qpos_proj(query_pos)# query_pos:torch.Size([300, 2, 256])
    q = q_content + q_pos
    k = k_content + k_pos
else:
    q = q_content
    k = k_content

El siguiente paso es Cross_Attentionel proceso de inicialización de Q, K y V enviados: Cabe señalar que la operación de atención separada se coloca fuera y originalmente se completó dentro de la atención.

q = q.view(num_queries, bs, self.nhead, n_model//self.nhead)# q分头:torch.Size([300, 2, 8, 32])
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)#query_sine_embed即
query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead)
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
#q经过拼接变为torch.Size([300, 2, 512])
k = k.view(hw, bs, self.nhead, n_model//self.nhead)#torch.Size([456, 2, 8, 32])
k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead)#torch.Size([456, 2, 8, 32])
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)#torch.Size([456, 2, 512])

Luego se enviará Q,K,Va Cross_Attention para el cálculo: Para resumir, q: antorcha.Tamaño([300, 2, 512]), k: antorcha.Tamaño([456, 2, 512]), v: antorcha.Tamaño( [456, 2, 256])

tgt2 = self.cross_attn(query=q, key=k, value=v, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]    

Específicamente ejecutar el siguiente proceso: diferentes dimensiones de QKV

return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, out_dim=self.vdim)

Después de completar el cálculo de cross_attention, la dimensión de tgt2 es torch.Size([300, 2, 256]). Para cambios de dimensión, consulte la fórmula de cálculo de Atención:

inserte la descripción de la imagen aquí
Luego, después de una serie de conexiones residuales, la operación de normalización por lotes genera el resultado y el resultado final sigue siendo torch.Size([300, 2, 256]).

Estrategia de actualización de anclaje

Este módulo también es un punto de innovación de DAB-DETR, es decir, la estrategia de actualización de anclaje Anchor Update

Es decir, después del cálculo de atención cruzada de DecoderLayer, el valor de salida se pasa a la siguiente capa de DecoderLayer, y también se usa para la actualización del punto de anclaje, usando la red MLP para obtener el desplazamiento de x, y, w, h y la forma es antorcha.Tamaño ([300, 2, 4]). Agréguelo a nuestras coordenadas de punto de referencia inicializadas reference_points(es decir, cuadro de anclaje, la forma es antorcha. Tamaño ([300, 2, 4])). Esta es la estrategia de actualización del punto de anclaje, y el anclaje de inicialización en el modelo DETR anterior siempre permanece sin cambios.

inserte la descripción de la imagen aquí

if self.bbox_embed is not None:
    if self.bbox_embed_diff_each_layer:#是否共享参数:false
        tmp = self.bbox_embed[layer_id](output)
    else:
        tmp = self.bbox_embed(output)#经过MLP获得output偏移量x,y,w,h torch.Size([300, 2, 4])
    # import ipdb; ipdb.set_trace()
    tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
    new_reference_points = tmp[..., :self.query_dim].sigmoid()
    if layer_id != self.num_layers - 1:
        ref_points.append(new_reference_points)
    reference_points = new_reference_points.detach()
if self.return_intermediate:
    intermediate.append(self.norm(output))

Se puede ver en el código anterior que los puntos de referencia se actualizarán continuamente, es decir, la estrategia de actualización de Anchor
Para lograr la diferenciación automática, PyTorch rastrea todas las operaciones que involucran tensores y es posible que deba calcular gradientes para ellos (es decir, require_grad es verdad). Estas operaciones se registran como un gráfico dirigido. El método detach() construye una nueva vista en un tensor que se declara que no requiere gradientes. Lo que
ejecuta el código anterior es el siguiente proceso de encuadre:
inserte la descripción de la imagen aquí

El módulo DecoderLayer de segunda capa

En comparación con la primera capa de DecoderLayer, la estructura de la segunda capa es la misma que la de la primera capa, excepto que el tgt de inicialización del Decoder-Embedding de la primera capa es todo 0, y la segunda capa se convierte en la salida de la primera capa Además, debido a la estrategia de actualización de anclaje, los cuadros de anclaje de la segunda capa también se convierten en los cuadros de anclaje de la primera capa más el desplazamiento de xywh.

El primero es el cambio de Anchor Boxes, reference_points(es decir, Anchor Boxes) después de pasar por la capa anterior de Decoderlayer, el valor se ha actualizado, y después de la codificación de posición de alta frecuencia nuevamente, la capa MLP cambia la dimensión de datos a antorcha. Tamaño ([300, 2, 256])

obj_center = reference_points[..., :self.query_dim]  
query_sine_embed = gen_sineembed_for_position(obj_center)  
query_pos = self.ref_point_head(query_sine_embed) 

Inmediatamente después, la diferencia se destaca aquí. Primero, query_scale_typeel cambio en este momento cond_elewise, y debido a la segunda capa, la salida (es decir, el resultado de la salida de la capa anterior) pasa

self.query_scale = MLP(d_model, d_model, d_model, 2)

La codificación para obtener pos_transformationla dimensión es torch.Size([300, 2, 256])

inserte la descripción de la imagen aquí

A continuación query_sine_embed[...,:self.d_model] * pos_transformation, el query_sine_embed aquí es torch.Size([300, 2, 512]), que toma las 256 dimensiones anteriores, es decir, la adquisición correspondiente es x, y. Multiplica con pos_transformation, pos_transformationes decir Xref,Yref, lo que se hace aquí es la siguiente operación:
inserte la descripción de la imagen aquí

Se puede ver que no es que no haya multiplicación en la primera capa PE(Xref),PE(Yref), sino porque su valor es 1.
El proceso posterior es exactamente el mismo que el de la primera capa de DecodeLayer.

Módulo decodificador

Ejecutar inmediatamente después de finalizar el bucle de DecoderLayer: intermediateguardar el resultado de cada capa, que es una Lista que contiene 6 valores, y la forma de cada valor es torch.Size([300, 2, 256]), y luego desechar la sexta capa (operación emergente) y luego agregue el valor de salida final.

 if self.norm is not None:
        output = self.norm(output)
        if self.return_intermediate:
            intermediate.pop()
            intermediate.append(output)

inserte la descripción de la imagen aquí
Luego juzgue si bbox_embed (encabezado de predicción de cuadro MLP) es Ninguno, y torch.stack es una operación de empalme.

 if self.return_intermediate:
            if self.bbox_embed is not None:
                return [
                    torch.stack(intermediate).transpose(1, 2),
                    torch.stack(ref_points).transpose(1, 2),
                ]
            else:
                return [
                    torch.stack(intermediate).transpose(1, 2), 
                    reference_points.unsqueeze(0).transpose(1, 2)
                ]

El módulo decodificador de Transformer finalmente devuelve el resultado:

hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, refpoints_unsigmoid=refpoint_embed)

Entre ellos, referencias torch.Size([6, 2, 300, 4]), hs es torch.Size([6, 2, 300, 256]), el resultado es también el resultado de retorno de Decoder, las referencias están después de cada se completa la actualización de la caja. hs se considera información de características semánticas.

inserte la descripción de la imagen aquí

Módulo integral DAB-DETR

Después de Transformer, el siguiente paso es diseñar el cabezal de clasificación y el cabezal de regresión.
El primero es desnormalizar el valor de la referencia (Anchor Box) del módulo Decoder, y luego realizar la predicción de cabeza de regresión en hs (Valor de salida del Decoder, equivalente a la salida de DETR) para obtener tmp, cuya forma es torch.Size( [6, 2, 300, 4]), agregue este valor a la referencia procesada. (self.query_dim es 4, es decir, se suman todos), y finalmente se normaliza tmp.

   if not self.bbox_embed_diff_each_layer:#是否权值共享
        reference_before_sigmoid = inverse_sigmoid(reference)#反归一化
        tmp = self.bbox_embed(hs)#torch.Size([6, 2, 300, 4])
        tmp[..., :self.query_dim] += reference_before_sigmoid
        outputs_coord = tmp.sigmoid()

inserte la descripción de la imagen aquí
El valor de output_coord es el xywh del cuadro de predicción y
inserte la descripción de la imagen aquí
el resultado final es:
pred_logits es la predicción de categoría (aquí hay 91 categorías) torch.Size([2, 300, 91])
pred_boxes es la predicción de cuadro box torch.Size( [2, 300, 4 ])
aux_outputs es el resultado de las primeras 5 capas de DecoderLayer. Es una lista con 5 valores.
inserte la descripción de la imagen aquí

Supongo que te gusta

Origin blog.csdn.net/pengxiang1998/article/details/130208479
Recomendado
Clasificación