Detailed explanation of the Loss module of mmdet


Foreword

 This article introduces the loss module of mmdet, and will gradually expand on how the loss functions are used and what to watch out for when using them.


1. Introduction to the loss function module in mmdet

1.1. The Loss registry

 First, look at the code in mmdet/models/builder.py:

from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

MODELS = Registry('models', parent=MMCV_MODELS)  # note the extra parent argument; we ignore it for now

BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS         # the Loss registry
DETECTORS = MODELS

 Note that the same MODELS registry is handed to all the other module types here; the reason for this design will be covered in a follow-up.
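 To see what the registry buys us, here is a minimal sketch (MyZeroLoss is a hypothetical class used only for illustration, not part of mmdet): once a class is decorated with @LOSSES.register_module(), build_loss can construct it from a plain config dict.

import torch
import torch.nn as nn
from mmdet.models.builder import LOSSES, build_loss

@LOSSES.register_module()
class MyZeroLoss(nn.Module):
    """Hypothetical loss used only to illustrate registration."""

    def forward(self, pred, target):
        # placeholder: always returns zero while keeping the graph connected
        return (pred - target).sum() * 0

obj = build_loss(dict(type='MyZeroLoss'))  # looked up via the 'type' key
print(obj(torch.ones(2), torch.ones(2)))   # tensor(0.)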

1.2. Registering L1Loss

@LOSSES.register_module()
class L1Loss(nn.Module):
    """L1 loss.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum".
        loss_weight (float, optional): The weight of loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(L1Loss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction, e.g. of shape [N].
            target (torch.Tensor): The ground-truth target, e.g. of shape [N].
            weight (torch.Tensor, optional): Per-element weight, shape [N].
                Defaults to None.
            avg_factor (int, optional): Factor the summed loss is divided by;
                its role overlaps with loss_weight. Defaults to None.
            reduction_override (str, optional): Overrides self.reduction for
                this call; its role overlaps with reduction. Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

 The initialization parameters are simple, just two: reduction defaults to 'mean', which returns the mean of the loss, and loss_weight scales the overall weight of the L1 loss. The forward method takes more parameters. pred and target must have the same shape, say [1000, 4] for bboxes; weight has the same shape as pred and controls how much each element contributes to the total loss; avg_factor and reduction_override are rarely used, since their roles overlap with loss_weight and reduction respectively, so you don't need to worry about them.
 With the parameters understood, let's work through a concrete example:

import torch
from mmdet.models import build_loss

loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)

# compute via the registered module
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]])   # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]
loss = obj(pred, target)
print(loss, 9/8)

 The result matches the hand calculation. Briefly, the computation: torch.abs takes the element-wise absolute difference, then .mean() divides the sum by the total number of elements, here 2*4 = 8, giving 9/8.
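 The same number can be reproduced with plain PyTorch, which makes the mean-over-all-elements behaviour explicit:

import torch

pred = torch.Tensor([[0, 2, 3, 0], [0, 2, 3, 0]])
target = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
print(torch.abs(pred - target).mean())  # tensor(1.1250), i.e. 9/8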
 Here is a version with weight:

import torch
from mmdet.models import build_loss

loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)

# compute via the registered module
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]])   # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]
# weighted version: the last element gets weight 0
weight = torch.Tensor([[1,1,1,1],[1,1,1,0]])     # [2,4]
loss = obj(pred, target, weight)
print(loss, 8/8)
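 Two parameters remain. Here is a quick sketch (reusing obj, pred and target from the snippet above): reduction_override swaps out self.reduction for a single call, while avg_factor replaces the element count as the divisor when reduction is 'mean'.

loss_sum = obj(pred, target, reduction_override='sum')  # 9.0: plain sum, no averaging
loss_avg = obj(pred, target, avg_factor=4)              # 9/4 = 2.25: sum divided by avg_factor
print(loss_sum, loss_avg)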

1.3. Internal implementation logic

 In essence, a decorator is used to encapsulate the loss computation. Briefly, the call chain is:
1) The forward method is called, which internally calls the l1_loss function:

@weighted_loss
def l1_loss(pred, target):
    """L1 loss.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.

    Returns:
        torch.Tensor: Calculated loss
    """
    if target.numel() == 0:
        return pred.sum() * 0

    assert pred.size() == target.size()
    loss = torch.abs(pred - target)  # element-wise absolute difference
    return loss

2) The @weighted_loss decorator is encountered. Execution jumps into the decorator first; note that the body of l1_loss is not run yet. See mmdet/models/losses/utils.py:

import functools

def weighted_loss(loss_func):
    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # element-wise loss between pred and target
        loss = loss_func(pred, target, **kwargs) 
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper

 The decorator first wraps loss_func, i.e. l1_loss, forwarding the extra **kwargs into it; only then is l1_loss actually executed, yielding the loss value for each element.
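 As a quick sketch of how reusable this pattern is, the same decorator can wrap any element-wise loss (my_l2_loss below is illustrative only, not part of mmdet):

import torch
from mmdet.models.losses.utils import weighted_loss

@weighted_loss
def my_l2_loss(pred, target):
    return (pred - target) ** 2  # element-wise squared error

pred = torch.Tensor([0., 2.])
target = torch.Tensor([1., 1.])
print(my_l2_loss(pred, target))                    # tensor(1.): (1 + 1) / 2
print(my_l2_loss(pred, target, reduction='none'))  # tensor([1., 1.])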
3) In the last step, weight_reduce_loss is executed to apply weight, reduction and avg_factor and produce the final loss:

import mmcv
import torch.nn.functional as F

def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()

@mmcv.jit(derivate=True, coderize=True)
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):

    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            loss = loss.sum() / avg_factor
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss
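 A quick check of weight_reduce_loss on its own, reusing the element-wise values from the earlier example, shows the typical use of avg_factor: dividing by the number of valid elements instead of the total element count.

import torch
from mmdet.models.losses.utils import weight_reduce_loss

elem = torch.Tensor([[1, 1, 2, 0], [1, 1, 2, 1]])   # |pred - target| from the example above
weight = torch.Tensor([[1, 1, 1, 1], [1, 1, 1, 0]])

print(weight_reduce_loss(elem, weight))                # (9 - 1) / 8 = 1.0
print(weight_reduce_loss(elem, weight, avg_factor=7))  # 8 / 7 ≈ 1.1429, 7 valid elements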

1.4. Summary

 Essentially, every loss in mmdet follows the computation process above. When using L1Loss you don't need to care about all these hyperparameters: simply build the loss and pass in pred and target, leaving the rest at their defaults.
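 In practice you rarely call build_loss yourself; the heads construct their losses internally from config entries. A typical pattern in a head config looks like this (exact keys vary per head):

loss_cls = dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
loss_bbox = dict(type='L1Loss', loss_weight=1.0)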

Summary

 To be continued...

Origin: blog.csdn.net/wulele2/article/details/125469970