[DETR source code analysis] 4. Loss calculation and post-processing module

Foreword

Recently I have been reading the DETR source code. After about a week of on-and-off reading, I have worked through the main model code. I spent some time deciding how to write up the analysis. One option is to go through it line by line, file by file, as I did for YOLOv5; the other is to organize the source code by functional module. After thinking it over, I chose the second approach: it should save time, and it also helps me understand the code as a whole.

To me, reading code means breaking the whole model down into individual functions and then seeing how all the modules connect, which gets twice the result with half the effort.

Another point I consider very important: when you get the code of an open-source project, set up the environment right away so that you can run it under a debugger, then use train.py to locate the code related to the main model and focus your analysis there. Code for logging, computing mAP, plotting, and so on can be ignored entirely, which saves a lot of time. So in these source code walkthroughs I strip out the irrelevant code without explanation and focus on the model, its improvements, the loss, and similar content.

This section mainly talks about the loss calculation and post-processing part of DETR. It mainly involves three files: models/matcher.py, models/detr.py and engine.py.

Annotated source code on GitHub: HuKai97/detr-annotations

1. Loss calculation: SetCriterion

First, the loss function is defined in detr.py:

criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,
                         eos_coef=args.eos_coef, losses=losses)
criterion.to(device)

Then, in train_one_epoch of engine.py, criterion is called after the forward pass to compute the loss:

# Forward pass
outputs = model(samples)
# Compute the losses. loss_dict holds 'loss_ce' + 'loss_bbox' + 'loss_giou';
# 'class_error' + 'cardinality_error' are used for logging only
loss_dict = criterion(outputs, targets)
# Weight coefficients {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}
weight_dict = criterion.weight_dict
# Total loss = regression losses: loss_bbox (L1) + loss_giou  +  classification loss: loss_ce
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
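
The weighted sum on the last line can be checked with plain numbers. The sketch below uses the weight coefficients from the comment above; the individual loss values are made up purely for illustration. It shows that keys without an entry in weight_dict (the logging-only ones) drop out of the total.

```python
# Hypothetical loss values for one batch; only the key names mirror the criterion's output.
loss_dict = {
    "loss_ce": 0.5,
    "loss_bbox": 0.25,
    "loss_giou": 0.75,
    "class_error": 25.0,       # logging only, no weight defined
    "cardinality_error": 2.0,  # logging only, no weight defined
}
weight_dict = {"loss_ce": 1, "loss_bbox": 5, "loss_giou": 2}

# Same expression as in train_one_epoch: unweighted keys are skipped.
total = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

# 0.5 * 1 + 0.25 * 5 + 0.75 * 2 = 3.25
print(total)
```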

Well, let's focus on the SetCriterion class:

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes     # number of classes in the dataset
        self.matcher = matcher             # HungarianMatcher(): Hungarian algorithm, bipartite matching
        self.weight_dict = weight_dict     # dict: 18 = 3 x 6, loss weights for the 6 decoder layers: 6 * (loss_ce + loss_giou + loss_bbox)
        self.eos_coef = eos_coef           # 0.1
        self.losses = losses               # list: 3  ['labels', 'boxes', 'cardinality']
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef   # tensor: 92; the first 91 entries are 1, the 92nd is eos_coef = 0.1
        self.register_buffer('empty_weight', empty_weight)
        
    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
                      dict: 'pred_logits'=Tensor[bs, 100, 92 classes]  'pred_boxes'=Tensor[bs, 100, 4]  output of the last decoder layer
                             'aux_outputs'={list:5}  0-4  each a dict:2 of pred_logits + pred_boxes, the outputs of the first 5 decoder layers
             targets: list of dicts, such that len(targets) == batch_size.   list: bs
                      each image carries: 'boxes', 'labels', 'image_id', 'area', 'iscrowd', 'orig_size', 'size'
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
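        # dict: 2   output of the last decoder layer: pred_logits[bs, 100, 92 classes] + pred_boxes[bs, 100, 4]
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Hungarian algorithm: solve the bipartite matching problem. From the 100 predicted boxes, find the
        # ones in one-to-one correspondence with the N gt boxes; the other 100-N are treated as background
        # Retrieve the matching between the outputs of the last layer and the targets  list:1
        # tuple: 2    0=Tensor[3]=[5, 35, 63]  the 3 matched predicted boxes; the other 97 are background
        #             1=Tensor[3]=[1, 0, 2]    the corresponding 3 gt boxes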
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)   # int: total number of gt boxes over all images in the batch, e.g. 3
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()   # 3.0

        # Compute the losses of the last decoder layer  Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # Compute the losses of the first 5 decoder layers and merge them into the final losses
        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)   # same Hungarian matching
                for loss in self.losses:   # compute each loss
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)
        # Losses that drive the weight update: 'loss_ce' + 'loss_bbox' + 'loss_giou'; logged only: 'class_error' + 'cardinality_error'
        return losses

The whole function does two main things:

  1. Call self.matcher to pick, from the 100 predicted boxes, the N predictions (N = number of gt boxes) that are matched one-to-one with the gt boxes;
  2. Call self.get_loss to compute each loss.
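
Because the same losses are computed once per decoder layer, the auxiliary entries are told apart by a `_{i}` suffix on the key. The sketch below shows how that naming leads to the 18-entry weight_dict mentioned in the `__init__` comment. It assumes 6 decoder layers (so 5 auxiliary outputs), and the loop itself is an illustration rather than code copied from the repository.

```python
base_weights = {"loss_ce": 1, "loss_bbox": 5, "loss_giou": 2}
num_decoder_layers = 6  # assumption: the layer count implied by the "3 x 6" comment

# Start with the weights of the last decoder layer (unsuffixed keys) ...
weight_dict = dict(base_weights)
# ... then add one suffixed copy per intermediate layer: loss_ce_0 ... loss_giou_4.
for i in range(num_decoder_layers - 1):
    weight_dict.update({f"{k}_{i}": v for k, v in base_weights.items()})

print(len(weight_dict))      # 3 losses x 6 layers = 18
print(sorted(weight_dict)[:4])
```

With this naming, the `if k in weight_dict` filter in the training loop picks up every layer's loss_ce, loss_bbox, and loss_giou, while the logging-only keys still fall through.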

1.1. Hungarian algorithm and bipartite graph matching: self.matcher

For the principle behind the Hungarian algorithm, see this classic blog post: Algorithm Study Notes (5): Hungarian Algorithm

In DETR, the HungarianMatcher class in models/matcher.py implements the Hungarian matching algorithm:

class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
    
    # No gradients are needed; this is only a matching step
    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes]=[bs,100,92] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4]=[bs,100,4] with the predicted box coordinates

            targets: list:bs This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes]=[3] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        # batch_size  100
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        # [2,100,92] -> [200, 92] -> [200, 92] probabilities
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        # [2,100,4] -> [200, 4]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        # [3]  idx = 32, 1, 85  concat all labels
        tgt_ids = torch.cat([v["labels"] for v in targets])
        # [3, 4]  concat all box
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the matching cost: classification + L1 box + GIoU box
        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

        # Final cost matrix   [bs*100, 3]  cost between each of the bs*100 predicted boxes and the 3 gt boxes
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()  # [bs, 100, 3]

        sizes = [len(v["boxes"]) for v in targets]   # number of gt boxes per image, e.g. 3

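        # Bipartite matching with the Hungarian algorithm: from the 100 predicted boxes, pick the 3 that are
        # paired with the gt boxes for the loss; this pairing has the minimum total cost
        # 0: [3]  5, 35, 63   indices of the predicted boxes matched to the gts
        # 1: [3]  1, 0, 2     indices of the corresponding gts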
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        
        # list: bs  matching results for the bs images
        # each image is a tuple: 2
        # 0 = Tensor[gt_num,]  indices of the matched positive samples       1 = Tensor[gt_num,]  indices of the gts
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

In short, the code first computes the cost between every predicted box (100 per image) and every gt box to form a cost matrix C, and then calls scipy.optimize.linear_sum_assignment to solve the assignment problem. The matching criterion is the minimum total cost. (This cost is not the real training loss; it is only a measure used for matching, and it is computed differently from the loss.) Each gt box ends up with exactly one prediction responsible for it, and all other predictions are automatically treated as background.
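
To make the shape of the cost matrix concrete, here is a minimal pure-Python sketch for one image with 3 predictions and 2 gt boxes. It is a simplification, not the matcher itself: it keeps only the classification term (minus the predicted probability of the gt class) and the L1 box term, leaves out the GIoU term, and the helper names `l1_distance` and `matching_cost` are hypothetical.

```python
def l1_distance(box_a, box_b):
    """Sum of absolute coordinate differences between two (cx, cy, w, h) boxes."""
    return sum(abs(a - b) for a, b in zip(box_a, box_b))

def matching_cost(pred_probs, pred_boxes, gt_labels, gt_boxes, w_class=1.0, w_bbox=1.0):
    """Return a [num_predictions][num_gts] cost matrix (GIoU term omitted in this sketch)."""
    cost = []
    for probs, box in zip(pred_probs, pred_boxes):
        row = []
        for label, gt_box in zip(gt_labels, gt_boxes):
            row.append(w_class * -probs[label] + w_bbox * l1_distance(box, gt_box))
        cost.append(row)
    return cost

# Class probabilities per prediction; the last column is the no-object class.
pred_probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]]
pred_boxes = [[0.5, 0.5, 0.2, 0.2], [0.3, 0.3, 0.1, 0.1], [0.8, 0.8, 0.1, 0.1]]
gt_labels = [1, 0]
gt_boxes = [[0.3, 0.3, 0.1, 0.1], [0.5, 0.5, 0.2, 0.2]]

C = matching_cost(pred_probs, pred_boxes, gt_labels, gt_boxes)
for row in C:
    print([round(c, 3) for c in row])
```

Prediction 1 has the lowest cost against gt 0 and prediction 0 against gt 1, since each predicts the right class with high probability and its box coincides with the gt box.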

linear_sum_assignment takes the cost matrix of a bipartite graph, computes the minimum-cost assignment, and returns the row indices (predicted box idx) and column indices (gt box idx) of the matching.
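
What "minimum-cost assignment" means can be shown with a brute-force stand-in. The function below is not the Hungarian algorithm and not SciPy's implementation; it simply tries every way of giving each gt box a distinct prediction and keeps the cheapest one, which is only feasible for tiny matrices. The name `brute_force_assignment` and the cost values are made up for illustration, and the sketch assumes there are at least as many predictions as gt boxes.

```python
from itertools import permutations

def brute_force_assignment(cost):
    """Exhaustive minimum-cost matching for a [num_predictions][num_gts] cost matrix."""
    num_preds, num_gts = len(cost), len(cost[0])
    best_rows, best_total = None, float("inf")
    # rows[j] is the prediction assigned to gt j; every gt gets a distinct prediction.
    for rows in permutations(range(num_preds), num_gts):
        total = sum(cost[r][j] for j, r in enumerate(rows))
        if total < best_total:
            best_rows, best_total = rows, total
    # Report the pairs sorted by prediction index, as (row indices, column indices).
    pairs = sorted(zip(best_rows, range(num_gts)))
    return [r for r, _ in pairs], [c for _, c in pairs]

cost = [
    [0.9, 0.1],  # prediction 0
    [0.2, 0.8],  # prediction 1
    [0.5, 0.6],  # prediction 2
    [0.7, 0.3],  # prediction 3
]
row_idx, col_idx = brute_force_assignment(cost)
print(row_idx, col_idx)  # [0, 1] [1, 0]
```

Here prediction 0 is matched to gt 1 and prediction 1 to gt 0 (total cost 0.3); predictions 2 and 3 stay unmatched and would be treated as background.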

1.2. Calculate loss: self.get_loss

self.get_loss is a function defined by the SetCriterion class:

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'masks': self.loss_masks
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

Through this dispatch, the classification loss (self.loss_labels), the regression loss (self.loss_boxes), and the cardinality loss are each called in turn. The cardinality loss is only used for logging and does not take part in gradient updates, so it is not described here. For segmentation tasks there is also a mask loss, which is not covered here for now.

1.2.1. Classification loss: self.loss_labels

Classification loss self.loss_labels:

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        outputs:'pred_logits'=[bs, 100, 92] 'pred_boxes'=[bs, 100, 4] 'aux_outputs'=5*([bs, 100, 92]+[bs, 100, 4])
        targets:'boxes'=[3,4] labels=[3] ...
        indices: [3] e.g. 5,35,63  indices of the 3 matched predicted boxes
        num_boxes: total number of gt boxes in the current batch
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']  # classification: [bs, 100, 92 classes]

        # idx tuple:2  0=[num_all_gt] which image each matched pair belongs to  1=[num_all_gt] index of each matched predicted box
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
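        # Positive + negative samples: matched predictions are positives and get their gt class idx;
        # the unmatched ones among the 100 are negatives (idx=91, background class)
        target_classes[idx] = target_classes_o

        # Classification loss = positive samples + negative samples
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        # Logging: record the Top-1 error
        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]

        # losses: 'loss_ce': classification loss
        #         'class_error': Top-1 error, i.e. 100 minus the percentage of matched predictions whose
        #                        highest-probability class equals the assigned GT class; logging only, not used for training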
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        # [num_all_gt]  idx of the image each matched prediction comes from
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        # idx of the matched predicted boxes
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

Note:

  1. Classification loss = cross-entropy loss;
  2. Positive samples + negative samples = 100; number of positive samples = number of GTs; number of negative samples = 100 - number of GTs;
  3. There are 92 categories, and idx=91 is the background (no-object) category;
  4. The _get_src_permutation_idx function gathers the matched predictions across the batch: it returns an image index and a query index for every matched pair, so predictions that were stored per image can be addressed as one flat list, which is convenient for computing the losses;
  5. class_error (100 minus the Top-1 accuracy over the matched predictions) is also computed here, for log display only.
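
The points above can be traced with a small pure-Python sketch: 2 images, 4 queries each, and 2 real classes, so index 2 plays the role that index 91 plays in DETR. The helper names are hypothetical stand-ins for the tensor code. The weighted cross-entropy follows what is, to my understanding, the behaviour of F.cross_entropy with class weights and mean reduction: a weighted sum of per-query negative log-likelihoods divided by the sum of the weights used.

```python
import math

def src_permutation_idx(indices):
    """List analogue of _get_src_permutation_idx: one (image, query) pair per match."""
    batch_idx = [img for img, (src, _) in enumerate(indices) for _ in src]
    src_idx = [q for (src, _) in indices for q in src]
    return batch_idx, src_idx

def build_target_classes(indices, gt_labels, batch_size, num_queries, num_classes):
    """Every query starts as no-object (index num_classes); matched queries get their gt class."""
    target = [[num_classes] * num_queries for _ in range(batch_size)]
    for img, (src, tgt) in enumerate(indices):
        for q, j in zip(src, tgt):
            target[img][q] = gt_labels[img][j]
    return target

def weighted_cross_entropy(logits, targets, class_weights):
    """Weighted mean of negative log-likelihoods, normalized by the summed weights."""
    weighted_sum, weight_total = 0.0, 0.0
    for row, t in zip(logits, targets):
        peak = max(row)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
        weighted_sum += class_weights[t] * (log_z - row[t])
        weight_total += class_weights[t]
    return weighted_sum / weight_total

indices = [([1, 3], [1, 0]), ([0], [0])]  # per image: (matched query idx, gt idx)
gt_labels = [[0, 1], [1]]
batch_idx, src_idx = src_permutation_idx(indices)
target_classes = build_target_classes(indices, gt_labels, 2, 4, 2)
print(batch_idx, src_idx)  # [0, 0, 1] [1, 3, 0]
print(target_classes)      # [[2, 1, 2, 0], [1, 2, 2, 2]]

empty_weight = [1.0, 1.0, 0.1]  # eos_coef = 0.1 down-weights the no-object class
flat_targets = [t for image in target_classes for t in image]
uniform_logits = [[0.0, 0.0, 0.0] for _ in flat_targets]
loss_ce = weighted_cross_entropy(uniform_logits, flat_targets, empty_weight)
```

With uniform logits every query has the same negative log-likelihood, log(3), so the weighted mean is log(3) as well. With real logits, the 0.1 weight keeps the many background queries from dominating the few matched ones.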

1.2.2. Regression loss: self.loss_boxes

Regression loss self.loss_boxes:

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        outputs:'pred_logits'=[bs, 100, 92] 'pred_boxes'=[bs, 100, 4] 'aux_outputs'=5*([bs, 100, 92]+[bs, 100, 4])
        targets:'boxes'=[3,4] labels=[3] ...
        indices: [3] e.g. 5,35,63  indices of the 3 matched predicted boxes
        num_boxes: total number of gt boxes in the current batch
        """
        assert 'pred_boxes' in outputs
        # idx tuple:2  0=[num_all_gt] which image each matched pair belongs to  1=[num_all_gt] index of each matched predicted box
        idx = self._get_src_permutation_idx(indices)

        # [all_gt_num, 4]  predicted box coordinates of all positive samples in this batch
        src_boxes = outputs['pred_boxes'][idx]
        # [all_gt_num, 4]  coordinates of all gt boxes in this batch
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        # Compute the L1 loss
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        # Compute the GIoU loss
        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes

        # 'loss_bbox': L1 regression loss   'loss_giou': GIoU regression loss
        return losses

Note:

  1. Regression loss: computed only over the positive samples (the matched predictions);
  2. Regression loss = L1 loss + GIoU loss, each summed over the matched pairs and divided by num_boxes.
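
Both terms can be worked through by hand for two matched pairs. The sketch below is a pure-Python illustration with hypothetical helper names, not the box_ops implementation: it converts (cx, cy, w, h) boxes to corner format, computes L1 and GIoU per pair, and normalizes by num_boxes as loss_boxes does. It assumes boxes with positive width and height.

```python
def cxcywh_to_xyxy(box):
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)

def l1(box_a, box_b):
    return sum(abs(a - b) for a, b in zip(box_a, box_b))

def giou(box_a, box_b):
    """Generalized IoU of two (x0, y0, x1, y1) boxes with positive area."""
    ax0, ay0, ax1, ay1 = box_a
    bx0, by0, bx1, by1 = box_b
    inter = max(0.0, min(ax1, bx1) - max(ax0, bx0)) * max(0.0, min(ay1, by1) - max(ay0, by0))
    union = (ax1 - ax0) * (ay1 - ay0) + (bx1 - bx0) * (by1 - by0) - inter
    # Area of the smallest box enclosing both inputs.
    hull = (max(ax1, bx1) - min(ax0, bx0)) * (max(ay1, by1) - min(ay0, by0))
    return inter / union - (hull - union) / hull

def box_losses(src_boxes, target_boxes, num_boxes):
    """Sum the per-pair L1 and (1 - GIoU) terms, then divide by the number of gt boxes."""
    pairs = list(zip(src_boxes, target_boxes))
    loss_bbox = sum(l1(s, t) for s, t in pairs) / num_boxes
    loss_giou = sum(1 - giou(cxcywh_to_xyxy(s), cxcywh_to_xyxy(t)) for s, t in pairs) / num_boxes
    return {"loss_bbox": loss_bbox, "loss_giou": loss_giou}

# Pair 1: perfect prediction. Pair 2: same size, shifted so the boxes only touch.
src = [(0.5, 0.5, 0.5, 0.5), (0.25, 0.25, 0.5, 0.5)]
tgt = [(0.5, 0.5, 0.5, 0.5), (0.75, 0.25, 0.5, 0.5)]
losses = box_losses(src, tgt, num_boxes=2)
print(losses)  # {'loss_bbox': 0.25, 'loss_giou': 0.5}
```

The first pair contributes nothing to either loss. The second pair has zero overlap, so its GIoU is 0 and its GIoU loss is 1, while its L1 distance is 0.5.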

2. Bbox post-processing: PostProcess

This part belongs to the evaluation stage. After the forward pass, the loss is computed for log display, and the predictions are post-processed so that the COCO metrics can be computed.

As before, the post-processing module is first defined in detr.py:

# Define the post-processing
postprocessors = {'bbox': PostProcess()}

Then, in evaluate of engine.py, PostProcess is called after the forward pass to post-process the 100 predicted boxes:

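# Forward pass
outputs = model(samples)
# Post-processing
# orig_target_sizes = [bs, 2]  original sizes of the bs images
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
# list: bs    each element is a dict with three fields: 'scores', 'labels', 'boxes'
# scores = Tensor[100,]  probability scores of the 100 predicted boxes for this image
# labels = Tensor[100,]  class idx of the 100 predicted boxes for this image
# boxes = Tensor[100, 4] absolute coordinates of the 100 predicted boxes (relative to the original image size)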
results = postprocessors['bbox'](outputs, orig_target_sizes)

PostProcess class:

class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""
    @torch.no_grad()
    def forward(self, outputs, target_sizes):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
                     0 pred_logits  classification head output [bs, 100, 92 (number of classes)]
                     1 pred_boxes  regression head output [bs, 100, 4]
                     2 aux_outputs list: 5  outputs of the first 5 decoder layers: 5 pred_logits[bs, 100, 92 (number of classes)] and 5 pred_boxes[bs, 100, 4]
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        # out_logits:[bs, 100, 92 (number of classes)]
        # out_bbox:[bs, 100, 4]
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

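        # [bs, 100, 92]  softmax over the class logits of each predicted box
        prob = F.softmax(out_logits, -1)
        # prob[..., :-1]: [bs, 100, 92] -> [bs, 100, 91]  drop the background class
        # .max(-1): scores=[bs, 100]  probability of the most likely class for each of the 100 predicted boxes
        #           labels=[bs, 100]  class of each of the 100 predicted boxes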
        scores, labels = prob[..., :-1].max(-1)

        # cxcywh to xyxy  format   [bs, 100, 4]
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        # and from relative [0, 1] to absolute [0, height] coordinates  widths and heights of the bs images
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]  # normalized coordinates -> absolute coordinates (relative to the original image)  [bs, 100, 4]

        results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]

        # list: bs    each element is a dict with three fields: 'scores', 'labels', 'boxes'
        # scores = Tensor[100,]  probability scores of the 100 predicted boxes for this image
        # labels = Tensor[100,]  class idx of the 100 predicted boxes for this image
        # boxes = Tensor[100, 4] absolute coordinates of the 100 predicted boxes (relative to the original image size)
        return results

As the code shows, post-processing simply reorganizes the predictions: it drops the background class and, for each image, returns the class probability scores, labels, and absolute coordinates of the 100 predicted boxes.
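
For a single query, these steps can be sketched in pure Python as follows. The function name `postprocess_query` and the numbers are made up for illustration; the logic mirrors the softmax, the removal of the last (background) column, and the scaling by image width and height.

```python
import math

def postprocess_query(logits, box_cxcywh, img_h, img_w):
    """Return (score, label, absolute xyxy box) for one query's raw outputs."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Drop the last entry (background) before taking the best class.
    foreground = probs[:-1]
    score = max(foreground)
    label = foreground.index(score)

    # Normalized (cx, cy, w, h) -> absolute (x0, y0, x1, y1) in the original image.
    cx, cy, w, h = box_cxcywh
    box = [(cx - w / 2) * img_w, (cy - h / 2) * img_h,
           (cx + w / 2) * img_w, (cy + h / 2) * img_h]
    return score, label, box

# A query whose most likely class is actually background (last logit is largest).
score, label, box = postprocess_query([0.0, 2.0, 5.0], (0.5, 0.5, 0.5, 0.5), img_h=480, img_w=640)
print(label, box)  # 1 [160.0, 120.0, 480.0, 360.0]
```

Note that a foreground label is always returned, even when the background probability is the highest; in that case the score is simply low, which is why a score threshold is needed afterwards.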

Finally, this result is passed to coco_evaluator to compute the COCO metrics.

At prediction time, an image generally does not contain 100 objects. How is this handled? Usually a threshold (0.7) is set on the prediction score: boxes whose score exceeds the threshold are kept and displayed, and boxes whose score falls below it are discarded.
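
A minimal sketch of that filtering step, applied to one per-image result dict with plain lists in place of tensors. The function name `filter_by_score` is hypothetical, and the scores, labels, and boxes are invented for the example; only the 0.7 threshold comes from the text.

```python
def filter_by_score(result, threshold=0.7):
    """Keep only the predictions whose score is above the threshold."""
    keep = [i for i, s in enumerate(result["scores"]) if s > threshold]
    return {key: [values[i] for i in keep] for key, values in result.items()}

result = {
    "scores": [0.95, 0.40, 0.71],
    "labels": [17, 3, 1],
    "boxes": [[10.0, 20.0, 110.0, 220.0], [0.0, 0.0, 50.0, 50.0], [30.0, 40.0, 90.0, 160.0]],
}
kept = filter_by_score(result)
print(kept["scores"], kept["labels"])  # [0.95, 0.71] [17, 1]
```

The three fields stay aligned because the same index list is applied to each of them.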

Reference

Official source code: https://github.com/facebookresearch/detr

Bilibili source code explanation: iron-clad assembly line workers

Zhihu [Brother Buffalo]: Interpretation of DETR source code

CSDN [squirrel working hard] source code explanation: DETR source code notes (1)

CSDN [squirrel working hard] source code explanation: DETR source code notes (2)

Zhihu [CV will not be wiped out]: Source code analysis of DETR, the crossover star of object detection (1), overview and model inference

Zhihu [CV will not be wiped out]: Source code analysis of DETR, the crossover star of object detection (2), model training process and data processing

Zhihu [CV will not be wiped out]: Source code analysis of DETR, the crossover star of object detection (3), Backbone and position encoding

Zhihu [CV will not be wiped out]: Source code analysis of DETR, the crossover star of object detection (4), Detection with Transformer

Zhihu [CV will not be wiped out]: Source code analysis of DETR, the crossover star of object detection (5), loss function and Hungarian matching algorithm

Zhihu [CV will not be wiped out]: Source code analysis of DETR, the crossover star of object detection (6), model output and prediction generation

Origin blog.csdn.net/qq_38253797/article/details/127618402