[Implementación del código] Interpretación del texto original DETR y detalles de implementación del código

1 descripción general del modelo

Insertar descripción de la imagen aquí

Macroscópicamente hablando, DETR consta principalmente de tres partes: la red troncal (CNN Backbone) basada en una red neuronal convolucional, la extracción de características y el interactor basado en TRM (Transformer) y el cabezal de clasificación y regresión basado en FFN, como como se muestra en la función build() en DETR. Lo más destacado de DETR es que abandona los procesos de procesamiento que no son de un extremo a otro, como NMS, generación de anclajes, etc., utiliza predicción de conjunto para modelar el proceso de detección de objetivos de un extremo a otro e introduce Transformer en el objetivo. detección, abriendo una nueva puerta al campo).

def build(args):

    backbone = build_backbone(args)
    transformer = build_transformer(args)

    model = DETR(
        backbone,# 骨干网
        transformer,# 重点部分
        num_classes=81,
        num_queries=100,# object query数量,作用相当于spatial embedding
        aux_loss=args.aux_loss)
    matcher = build_matcher(args)# 二分图匹配
    weight_dict = {
    
    'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef
    losses = ['labels', 'boxes', 'cardinality']
    postprocessors = {
    
    'bbox': PostProcess()}

    return model, criterion, postprocessors

2 Proceso básico de DETR

  • La columna vertebral de CNN extrae características de la imagen
  • Transformer Encoder mejora las funciones modelando relaciones globales con atención propia
  • La entrada de Transformer Decoder son consultas de objetos (incrustación espacial) y la salida del codificador Transformer (incrustación de contenido), que incluye principalmente los procesos de autoatención y atención cruzada. La atención propia interactúa principalmente con cada consulta, de modo que cada consulta puede ver qué están consultando otras consultas, para no repetirlas, similar a la función de NMS; la atención cruzada trata principalmente la consulta de objetos como una consulta, y el codificador Característica se utiliza como clave para consultar el área relacionada con la consulta.
  • Para la consulta generada por el decodificador, utilice FFN para extraer la información de posición y categoría del fotograma de destino.

Sigue el proceso básico anterior y comienza con el código para entender la idea original poco a poco. ¡Empecemos a continuación!

3 columna vertebral

La primera es la función para construir el módulo backbone.

def build_backbone(args):

    position_embedding = build_position_encoding(args)
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model

3.1 Generar codificación de posición basada en el mapa de características

def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {
      
      args.position_embedding}")

    return position_embedding

Para la característica de entrada x, suponiendo que su tamaño es BxCxHxW, la codificación de posición requiere la codificación de posición en las dos dimensiones HW, por lo que hide_dim (C, canal) generalmente se divide en dos partes, una parte representa H y la otra parte representa W, y finalmente en el empalme se realiza en la dimensión del canal. La codificación de posición en DETR consiste principalmente en incrustación de posición sinusoidal y de aprendizaje. He escrito una implementación que puede aprender codificación posicional basada en TRM. Puedes probarla ejecutándola.

# 验证learnable pos embedding机制
x = torch.randn((8, 3, 32, 32))
h,w=x.shape[-2:]
row_embed,col_embed=nn.Embedding(50,256),nn.Embedding(50,256)

i,j=torch.arange(w,device=x.device),torch.arange(h,device=x.device)
x_emb,y_emb=col_embed(i),row_embed(j)
x_cat=x_emb.unsqueeze(0).repeat(h,1,1)
y_cat=y_emb.unsqueeze(1).repeat(1,w,1)

pos=torch.cat([x_cat,y_cat],dim=-1)
pos_learn=pos.permute(2,0,1).unsqueeze(0).repeat(x.shape[0],1,1,1)# shape:(8,512,32,32)

Construir columna vertebral

Esta parte no es demasiado difícil, los comentarios relevantes ya están escritos en el bloque de código.

class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:# 是否返回中间层,在多尺度融合操作时会用到
            return_layers = {
    
    "layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:# 一般情况下只返回最后一层的输出
            return_layers = {
    
    'layer4': "0"}
        # IntermediateLayerGetter作用类似于Sequential,将多个神经层组合并可以指定返回中间输出
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)# 将数据传入网络,实例化网络得到输出,xs即为经过resnet四部分后的输出
        out: Dict[str, NestedTensor] = {
    
    }# 定义输出格式
        for name, x in xs.items():# 如果返回中间层,out可以按照name存储,返回最后一层则只有layer4
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]# 利用插值法产生不同尺度下的mask
            out[name] = NestedTensor(x, mask)
        return out

Combine los dos anteriores en consecuencia

class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)# [0]代表backbone的输出
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))# [1]代表position_embedding

        return out, pos# 返回抽取后的特征及对应的位置编码

Transformador

Insertar descripción de la imagen aquí
La parte TRM es en realidad exactamente la misma que la estructura del modelo de Atención, es todo lo que necesita. La parte diferente es solo la entrada de la parte del decodificador. En el transformador original, la entrada del decodificador es la incrustación codificada por la posición de fusión correspondiente a la secuencia objetivo. En este artículo, tgt inicializado a todos 0 se usa como secuencia objetivo y luego se fusiona query_embed . Una cosa que es muy fácil de confundir aquí es: la secuencia tgt todo cero es la incrustación de contenido. Query_embed en el código es la incrustación espacial que representa la posición de la colección de cuadros de destino. Dado que
la parte del codificador es principalmente extracción de características y Tiene poco impacto en el posicionamiento de límites, consideramos el decodificador. En la parte de atención cruzada, su entrada incluye principalmente tres partes: valor de clave de consulta

  • Consultas: cada consulta es el resultado de la primera capa de autoatención del decodificador ( consulta de contenido ) + consulta de objeto ( consulta espacial ). La consulta de objeto aquí es el concepto propuesto en DETR. Cada consulta de objeto es la información del cuadro candidato. Después de FFN, se puede generar información de ubicación y categoría (el número de consultas de objetos N en este artículo es 100)
  • Claves: cada clave se compone de la función de salida del codificador ( clave de contenido ) + codificación de posición ( clave espacial )
  • valores: solo salida del codificador
class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,return_intermediate=return_intermediate_dec)

        self.d_model = d_model
        self.nhead = nhead

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)# flatten(k)表示将[k:n-1]拉平为一个维度
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)# sine
        # num_queries x hidden_dim to num_queries x N x hidden_dim
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        # NxHxW to NxHW
        mask = mask.flatten(1)

        # decoder embedding,初始化为全0
        tgt = torch.zeros_like(query_embed)
        # encoder特征抽取,得到memory,shape同tgt
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # decoder特征交互,得到hs
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

DETR

class DETR(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):

        super().__init__()
        self.num_queries = num_queries# object query nums
        self.transformer = transformer
        hidden_dim = transformer.d_model# 隐层维度
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)# 分类头,最后的类别为:类别数+背景
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)# 检测头,使用三层全连接层进行映射,最后投影到xywh
        self.query_embed = nn.Embedding(num_queries, hidden_dim)# object query
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)# 将得到的feature map通道归一
        self.backbone = backbone
        self.aux_loss = aux_loss

    def forward(self, samples: NestedTensor):

        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        # 得到的feature可能是C3-C5几层,DETR只拿最后一层输入TRM
        src, mask = features[-1].decompose()
        assert mask is not None
        # self.transformer()[0]表示取dncoder的输出,序列1表示encoder输出
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {
    
    'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

Supongo que te gusta

Origin blog.csdn.net/weixin_43427721/article/details/132676577
Recomendado
Clasificación