Study notes | DETR (an object detection framework built on the Transformer): an introduction to DETR and a practical look at the Transformer

1. Overview

DETR, short for DEtection TRansformer, is a Transformer-based end-to-end object detection network proposed by Facebook and published at ECCV 2020.

Original text:
link
Source code:
link

DETR is the first object detection framework to successfully use the Transformer as the central building block of the detection pipeline. Built on the Transformer, it performs end-to-end detection with no NMS post-processing step and no anchors, and it achieves results on par with Faster R-CNN.

A comparison of results on the COCO dataset is shown below:

[Figure: COCO comparison results, from the original paper]
As the figure shows, DETR performs very well: with a ResNet-50 backbone and suitable tuning it matches Faster R-CNN. DETR performs best on large objects but somewhat worse on small ones, and its matching-based loss makes training slow to converge (i.e. it is hard to reach the optimum). Deformable DETR later brought substantial improvements on both of these problems.

The main building block of the new Detection Transformer (DETR) framework is a set-based global loss that forces unique predictions via bipartite matching, together with a Transformer encoder-decoder architecture. Given a fixed small set of learned object queries, DETR reasons about the relations between the objects and the global image context, and directly outputs the final set of predictions in parallel.

Unlike many other modern detectors, the new model is conceptually simple and does not require specialized libraries. DETR achieves comparable accuracy and runtime performance to the well-established and highly optimized Faster R-CNN baseline on the challenging COCO object detection dataset. Furthermore, DETR can be easily transferred to other tasks such as panoptic segmentation.

The Detection Transformer predicts all objects at once and is trained end-to-end with a loss that performs bipartite matching between predicted and ground-truth objects. DETR simplifies the detection pipeline by removing multiple hand-designed components that encode prior knowledge, such as NMS and anchor generation. Unlike most existing detection methods, DETR does not require any custom layers and can therefore be reproduced easily in any framework that provides standard CNN and Transformer classes.

2. Transformer

The Transformer (see the original paper) has been widely used since it was proposed. Its core is the stacked use of the attention mechanism, which lets a model selectively focus on certain parts of its input and thus reason more efficiently. It has achieved remarkable results in NLP and is now being adopted in CV as well. It is still essentially an Encoder-Decoder structure: both the encoder and the decoder stack multiple Self-Attention modules. Through encoding and decoding with multi-head attention, the model learns the important features, and combining the two lets it integrate the "context and word-order" information of the image to perform better detection. The module structure is shown in the figure below:

[Figure: the Transformer's Encoder-Decoder architecture]
Encoder module:
[Figure: the Transformer Encoder module]

Compared with traditional sequence models such as the RNN, the Transformer's main improvements are:

  1. It replaces the RNN with a stack of multiple Self-Attention modules.
  2. It computes, in parallel, the correlation of every element in the sequence with all other elements, efficiently extracting contextual dependencies, and introduces the Multi-Head Attention mechanism to extract features from multiple perspectives (a minimal sketch follows this list).
  3. Positional encoding is used to describe the order information of the sequence, replacing the RNN's serial computation.
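
To make point 2 concrete, here is a minimal, self-contained sketch of scaled dot-product self-attention in PyTorch. This is not the DETR code; the tensor sizes and the shortcut of reusing x as Q, K and V are assumptions made purely for illustration.

import torch
import torch.nn.functional as F

# a toy sequence: length 5, batch 1, embedding dimension 8
x = torch.randn(5, 1, 8)

# in a real model Q, K and V come from learned linear projections of x;
# they are reused directly here to keep the sketch minimal
q = k = v = x.transpose(0, 1)                           # (batch, seq, dim)

scores = q @ k.transpose(1, 2) / (q.size(-1) ** 0.5)    # (1, 5, 5): every position vs. every other
weights = F.softmax(scores, dim=-1)                     # attention weights
out = weights @ v                                       # (1, 5, 8)
print(out.shape)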

Transformer's PyTorch implementation

The following demonstrates the practical usage of the Transformer through PyTorch's interface.

PyTorch's encapsulation of the Transformer is provided by torch.nn.Transformer, whose main parameters are:

torch.nn.Transformer(d_model: int = 512,
                     nhead: int = 8,
                     num_encoder_layers: int = 6,
                     num_decoder_layers: int = 6,
                     dim_feedforward: int = 2048,
                     dropout: float = 0.1,
                     activation: str = 'relu',
                     custom_encoder: Optional[Any] = None,
                     custom_decoder: Optional[Any] = None)

Among them:
d_model is the channel dimension of the word embeddings;
nhead is the number of heads in the multi-head attention;
num_encoder_layers and num_decoder_layers are the numbers of self-attention layers in the encoder and the decoder, respectively;
dim_feedforward is the dimension of the Linear (feed-forward) layers inside the encoder and decoder.

The forward function of nn.Transformer implements the process of encoding and decoding:

forward(src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor

The two required arguments are src and tgt, corresponding to the encoder input (inputs) and the decoder input (outputs) respectively. tgt acts roughly as a conditional constraint: the tgt fed to the first decoder layer is a word-embedding vector, while from the second layer onward it is the output of the previous layer.

Among the other, optional arguments, [src/tgt/memory]_mask are mask arrays that define the attention strategy, corresponding to Section 3.1 of the original paper. An informal explanation: in a word sequence, each word should only be influenced by the words before it, so all positions after it are ignored; when attention is computed, the correlation between a word vector and the vectors that follow it is therefore set to 0. (The author's aside: in reality each word, especially in Chinese, is related to the semantics and order of its whole context, and modeling this is what lets the model learn a word's precise meaning.)

[src, tgt, memory]_key_padding_mask are also mask arrays; they specify which positions in src, tgt and memory should be kept and which should be ignored.
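
A minimal usage sketch of nn.Transformer (the shapes here are arbitrary and chosen only for illustration; by default nn.Transformer expects sequence-first tensors of shape (S, N, E)):

import torch
import torch.nn as nn

model = nn.Transformer(d_model=512, nhead=8,
                       num_encoder_layers=6, num_decoder_layers=6)

src = torch.randn(10, 32, 512)   # (source length, batch, d_model)
tgt = torch.randn(20, 32, 512)   # (target length, batch, d_model)

out = model(src, tgt)            # encode src, then decode conditioned on tgt
print(out.shape)                 # torch.Size([20, 32, 512])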

3. DETR

The idea behind DETR is essentially similar to traditional object detection, but it is expressed very differently. Traditional anchor-based methods essentially classify predefined dense anchors and regress box offsets. DETR instead treats detection as a set prediction problem (the set plays a role similar to anchors). Since the Transformer is essentially a sequence-to-sequence transformation, DETR can be viewed as mapping an image sequence to a set sequence. This set is in fact a learnable positional encoding (called object queries or output positional encoding in the paper, and query_embed in the code).

DETR's network architecture (algorithm flow):
[Figure: DETR network architecture]
The Transformer structure used by DETR:
[Figure: the Transformer used by DETR]
The spatial positional encoding is a two-dimensional positional encoding proposed by the authors. It is added to the encoder's self-attention and to the decoder's cross-attention; the object queries are likewise added to both attention modules of the decoder. The original Transformer instead added positional encodings to the input and output embeddings. It is worth noting that the ablation study shows that even with no positional encoding in the encoder at all, the final AP drops by only 1.3 points compared with the full DETR.

The code rewrites the TransformerEncoderLayer and TransformerDecoderLayer classes based on PyTorch. The only PyTorch interface used is the nn.MultiheadAttention class. The source code requires PyTorch 1.5.0 or higher.

The code core is located in models/transformer.py and models/detr.py .

transformer.py

# abridged from models/transformer.py; TransformerEncoder, TransformerDecoder,
# TransformerEncoderLayer and TransformerDecoderLayer are defined in the same file
import torch
from torch import nn

class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

The Transformer class contains an Encoder object and a Decoder object; the related classes are also implemented in transformer.py. Focus on the forward function: it starts with a reshaping of the input tensor, # flatten NxCxHxW to HWxNxC.

Combined with PyTorch's shape definitions for src and tgt, you can see DETR's idea: **flatten the pixels of the backbone's output feature map into one dimension and treat that as the sequence length, while keeping the batch and channel definitions unchanged.** DETR can therefore compute the correlation of each feature-map pixel with every other pixel, something a CNN achieves only through its receptive field; the Transformer can thus capture a larger receptive range than a CNN.

DETR does not use masked attention when computing attention: after the feature map is flattened into one dimension, every pixel may be related to every other pixel, so no mask is needed. src_key_padding_mask is used only to mask out the zero-padded regions.
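
A quick shape check of the flatten NxCxHxW to HWxNxC step (the numbers are arbitrary and only illustrative):

import torch

src = torch.randn(2, 256, 25, 34)        # backbone feature map: N, C, H, W
seq = src.flatten(2).permute(2, 0, 1)    # -> (H*W, N, C): pixels become the sequence
print(seq.shape)                         # torch.Size([850, 2, 256])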

The forward function has two key variables, pos_embed and query_embed. pos_embed is the positional encoding, implemented in models/position_encoding.py.

position_encoding.py

According to the characteristics of the two-dimensional feature map, DETR implements its own two-dimensional position encoding method. The implementation code is as follows:

# abridged from models/position_encoding.py
import math

import torch
from torch import nn

from util.misc import NestedTensor

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

Here, mask is a positional mask array. For an image that has not been zero-padded, the mask is all zeros.

Reading the code, DETR computes a positional encoding for the x and y directions of the two-dimensional feature map. The encoding length per direction is num_pos_feats (in practice half of hidden_dim); for each of x and y, sine is applied at even indices and cosine at odd indices, then pos_x and pos_y are concatenated into an NHWD array, which permute(0, 3, 1, 2) turns into NDHW, where D equals hidden_dim. hidden_dim is the dimension of the Transformer's input vectors and, by construction, must equal the channel dimension of the feature map output by the CNN backbone; the positional encoding therefore has exactly the same shape as the CNN output feature.

src = src.flatten(2).permute(2, 0, 1)         
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)

The features and positional encoding output by the CNN are flattened and permuted so that their shape becomes (S, N, E), which matches PyTorch's input shape convention. In the TransformerEncoder, src and pos_embed are added together; the relevant part of the encoder layer is shown below.
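
Abridged from models/transformer.py (the __init__ and the pre-norm branch are omitted): the positional encoding is added to the queries and keys, but not to the values.

class TransformerEncoderLayer(nn.Module):
    def with_pos_embed(self, tensor, pos):
        # add the positional encoding if one is given
        return tensor if pos is None else tensor + pos

    def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        q = k = self.with_pos_embed(src, pos)    # positions enter only the queries and keys
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src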

detr.py

class DETR

This class encapsulates the entire DETR computation. First, let's see what the object queries repeatedly mentioned in the paper actually are.

The answer is query_embed .

In the code, query_embed is actually an embedding array:

self.query_embed = nn.Embedding(num_queries, hidden_dim)

Here num_queries is the number of predefined object queries, 100 by default in the code. The meaning is: **based on the features produced by the Encoder, the Decoder converts the 100 queries into 100 predicted objects.** 100 queries are usually enough; few images contain more than 100 objects (except in extremely dense scenes). By contrast, CNN-based methods predict tens of thousands of anchors, at a much higher computational cost.
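
A quick look at what that embedding holds (hidden_dim is 256 in DETR's default configuration; this snippet is only illustrative):

import torch
from torch import nn

num_queries, hidden_dim = 100, 256
query_embed = nn.Embedding(num_queries, hidden_dim)

# the learnable weight table itself is what gets handed to the Transformer
print(query_embed.weight.shape)   # torch.Size([100, 256])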

The Transformer's forward function defines an all-zero tensor tgt with the same shape as query_embed. In TransformerDecoderLayer's forward, query_embed is then added to tgt (here query_embed plays a role similar to a positional encoding): it is added to the query and key in the self-attention, and to the query in the cross (multi-head) attention:

# abridged from models/transformer.py
from typing import Optional

from torch import nn, Tensor

class TransformerDecoderLayer(nn.Module):
    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

After the object queries pass through the decoder, the output is a tensor of shape (T, N, E), where T is the sequence length of the object queries (i.e. 100), N is the batch size, and E is the feature channel dimension.

Finally, the class prediction is output through a Linear layer, and the box prediction through a multi-layer perceptron (MLP):

self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

# forward
hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
outputs_class = self.class_embed(hs)
outputs_coord = self.bbox_embed(hs).sigmoid()
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}

The classification output has num_classes + 1 channels: categories start from 0, and the background ("no object") class is num_classes.
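
For reference, a minimal sketch of how detections might be read off this output at inference time, following the approach of the official demo notebook; the batch size of 1 and the 0.7 confidence threshold are assumptions for illustration:

import torch

# out = {'pred_logits': ..., 'pred_boxes': ...} as built above, batch size assumed to be 1
probas = out['pred_logits'].softmax(-1)[0, :, :-1]   # drop the background ("no object") class
keep = probas.max(-1).values > 0.7                   # keep only confident queries

scores, labels = probas[keep].max(-1)
boxes = out['pred_boxes'][0, keep]                   # normalized (cx, cy, w, h), still to be rescaled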

class SetCriterion

This class is responsible for the calculation of loss.

CNN-based methods compute a prediction for every anchor, compute the IoU between each prediction and the ground-truth boxes, and take anchors whose IoU exceeds a threshold as positive samples, regressing their class and box deltas. Similarly, DETR computes a prediction for every object query, but it directly regresses the normalized box coordinates (center_x, center_y, w, h) rather than deltas relative to anchors.

These object predictions are then matched one-to-one against the ground-truth boxes via bipartite matching; DETR uses the Hungarian algorithm to carry out this matching.

Complete flow chart:
[Figure: complete flow chart of DETR]
If an image contains N objects, then N of the 100 object predictions will be matched to the N ground-truth boxes, while all the others are matched to "no object"; those predictions are assigned the class label num_classes, i.e. they are predicted as background.

This design is one of DETR's highlights and defining characteristics: in theory each object query has a unique matching target, with no overlap, so DETR needs no NMS post-processing.
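
A toy sketch of this bipartite matching step, using SciPy's implementation of the Hungarian algorithm. This mimics what models/matcher.py does, but the cost matrix here is random; the real cost combines the classification, L1 and GIoU terms.

import torch
from scipy.optimize import linear_sum_assignment

num_queries, num_targets = 5, 2                 # toy sizes; DETR uses 100 queries
cost = torch.rand(num_queries, num_targets)     # stand-in for the class + L1 + GIoU cost matrix

src_idx, tgt_idx = linear_sum_assignment(cost.numpy())
# each ground-truth box is matched to exactly one query, minimizing the total cost
print(list(zip(src_idx, tgt_idx)))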

The loss is then computed from the matching result; the loss formulas are not reproduced here.

# abridged from models/detr.py
class SetCriterion(nn.Module):
    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes accross all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

outputs_without_aux and targets are matched via self.matcher. The Hungarian algorithm returns an indices tuple containing the indices of src (predictions) and target (ground truth); see models/matcher.py for the details of the matching process.

classification loss

The classification loss applies a cross-entropy loss to all predictions:

def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
    src_logits = outputs['pred_logits']

    idx = self._get_src_permutation_idx(indices)
    target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
    target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
    target_classes[idx] = target_classes_o

    loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
    losses = {'loss_ce': loss_ce}

    return losses

target_classes_o gathers all matched ground-truth classes according to the target indices, and they are written into the corresponding positions of target_classes according to the src indices. Predictions that are not matched are filled with self.num_classes in target_classes. The helper _get_src_permutation_idx extracts, from the indices tuple, the batch index of each matched src prediction and its matching index.
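
For reference, that helper in models/detr.py is roughly as follows (lightly commented):

def _get_src_permutation_idx(self, indices):
    # permute predictions following indices: batch index and query index of every matched prediction
    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.cat([src for (src, _) in indices])
    return batch_idx, src_idx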

box loss

The box loss applies an L1 loss and a GIoU loss to the successfully matched predictions:

def loss_boxes(self, outputs, targets, indices, num_boxes):
    """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
       targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
       The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
    """
    assert 'pred_boxes' in outputs
    idx = self._get_src_permutation_idx(indices)
    src_boxes = outputs['pred_boxes'][idx]
    target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

    loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

    losses = {}
    losses['loss_bbox'] = loss_bbox.sum() / num_boxes

    loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
        box_ops.box_cxcywh_to_xyxy(src_boxes),
        box_ops.box_cxcywh_to_xyxy(target_boxes)))
    losses['loss_giou'] = loss_giou.sum() / num_boxes
    return losses

target_boxes are the ground-truth boxes of all successful matches, gathered by the target indices; src_boxes are the matched predictions, gathered by the src indices; the L1 loss and GIoU loss between them are then computed.
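
The box-format conversion used above is a small utility in util/box_ops.py; it looks roughly like this, and the GIoU is then computed on the resulting (x0, y0, x1, y1) boxes:

import torch

def box_cxcywh_to_xyxy(x):
    # convert (center_x, center_y, w, h) to (x0, y0, x1, y1)
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)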

Application of DETR to panoptic segmentation (a brief look)

Adding a mask head to each object embedding output by the Decoder enables pixel-level segmentation. The mask head can be trained jointly with the box embedding, or trained separately after the box embedding has been trained.

DETR's approach is similar to Mask R-CNN, which predicts an instance segmentation inside each given box prediction. Here DETR upsamples the attention maps output by the mask head, adds them to several backbone feature branches in an FPN-like fashion, and then performs a pixel-wise argmax over the mask maps of all boxes to obtain the final segmentation.

Finally (personal opinion)

The application of the Transformer structure in CV shows how closely the CV and NLP fields are related, and how the major sub-fields of computing promote one another. That said, the Transformer's strength in modeling semantics and word order remains most pronounced in natural language processing. Applied to CV at large, I personally think its particular suitability for sequence-like feature data is what lets it achieve very good results on relation-heavy tasks such as object detection and image understanding. Perhaps subsequent improvements will connect the two fields even more closely and yield better results.

This article is only for learning, sharing and note-taking; please contact the author for removal in case of any infringement.

Original post: blog.csdn.net/qq_53250079/article/details/127457575