Detailed explanation of CenterFusion's loss function: initialization and forward pass

1. Call path leading to loss function initialization

At line 51 of main() in CenterFusion/src/main.py, the Trainer class is instantiated:

trainer = Trainer(opt, model, optimizer)

This invokes Trainer's __init__() function, defined at line 130 of CenterFusion/src/lib/trainer.py:

def __init__(
  self, opt, model, optimizer=None):
  self.opt = opt
  self.optimizer = optimizer
  self.loss_stats, self.loss = self._get_losses(opt)          # initialize the loss function
  self.model_with_loss = ModelWithLoss(model, self.loss, opt) # initialize model_with_loss

2. Loss function initialization

2.1 Loss function initialization: the _get_losses() function of the Trainer class

First, we look at the initialization of the loss function, the _get_losses() function. It is defined at line 238 of CenterFusion/src/lib/trainer.py.
Note: _get_losses() is called at line 134 of CenterFusion/src/lib/trainer.py.

  def _get_losses(self, opt):
    loss_order = ['hm', 'wh', 'reg', 'ltrb', 'hps', 'hm_hp', \
      'hp_offset', 'dep', 'dep_sec', 'dim', 'rot', 'rot_sec',
      'amodel_offset', 'ltrb_amodal', 'tracking', 'nuscenes_att', 'velocity']
    # First list every head the model may regress; the heads actually present
    # in opt.heads determine which loss functions are used later
    loss_states = ['tot'] + [k for k in loss_order if k in opt.heads]
    # After this, loss_states is ['tot', 'hm', 'wh', 'reg', 'dep', 'dep_sec', 'dim', 'rot', 'rot_sec', 'amodel_offset', 'nuscenes_att', 'velocity']
    loss = GenericLoss(opt)  # all loss functions are wrapped in the GenericLoss class
    return loss_states, loss

The loss class GenericLoss is defined at line 23 of CenterFusion/src/lib/trainer.py. Its constructor first calls the constructor of its torch.nn.Module parent class, located at line 223 of anaconda3/envs/pytorch17/lib/python3.7/site-packages/torch/nn/modules/module.py:

class GenericLoss(torch.nn.Module):
  def __init__(self, opt):
    super(GenericLoss, self).__init__()       # call the Module parent constructor to initialize GenericLoss
    self.crit = FastFocalLoss(opt=opt)        # focal loss, used later for the classification loss
Only the initialization part of the FastFocalLoss class is shown here; the class is defined at line 72 of CenterFusion/src/lib/model/losses.py. The forward() functions, together with the helpers they call, are shown later when the actual computation is described; this part mainly covers the __init__() functions of the classes that will be used later.

class FastFocalLoss(nn.Module):
  '''
  Reimplemented focal loss, exactly the same as the CornerNet version.
  Faster and costs much less memory.
  (A reimplementation of the original CenterNet focal loss, identical to the
  CornerNet version but faster and with a smaller memory footprint.)
  '''
  def __init__(self, opt=None):
    super(FastFocalLoss, self).__init__()  # call the Module parent constructor to initialize FastFocalLoss
    self.only_neg_loss = _only_neg_loss    # the loss for negative samples, defined below

The loss for negative samples in the paper is
$$\left(1-Y_{xyc}\right)^{\beta}\left(\widehat{Y}_{xyc}\right)^{\alpha}\log\left(1-\widehat{Y}_{xyc}\right)$$

where $Y_{xyc}$ is the ground truth and $\widehat{Y}_{xyc}$ is the predicted score.
The _only_neg_loss() function implements this formula; it is located at line 67 of CenterFusion/src/lib/model/losses.py:

def _only_neg_loss(pred, gt):
  gt = torch.pow(1 - gt, 4)                                  # (1 - Y)^beta with beta = 4
  neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * gt   # log(1 - Y_hat) * Y_hat^alpha with alpha = 2
  return neg_loss.sum()
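
A minimal, self-contained sketch of what this computes on toy tensors (shapes and values invented for illustration; note how the Gaussian peak of 1.0 at the ground-truth center is suppressed by the $(1-Y)^{4}$ factor):

import torch

def _only_neg_loss(pred, gt):
  gt = torch.pow(1 - gt, 4)
  neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * gt
  return neg_loss.sum()

# A 1 x 1 x 2 x 2 "heatmap": the ground truth is 1.0 at the object center,
# so the (1 - gt)^4 factor zeroes that position out and only the remaining
# (negative) positions contribute to the loss.
gt   = torch.tensor([[[[1.0, 0.6], [0.1, 0.0]]]])
pred = torch.tensor([[[[0.9, 0.5], [0.2, 0.1]]]])
print(_only_neg_loss(pred, gt))  # a single negative scalar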

The remaining losses can be examined in detail while debugging. They are all classes in CenterFusion/src/lib/model/losses.py, structured like FastFocalLoss, so they are not introduced one by one; each initializes itself by calling the constructor of the Module parent class:

    self.crit_reg = RegWeightedL1Loss()          # regression loss for heads such as wh and reg
    if 'rot' in opt.heads:
      self.crit_rot = BinRotLoss()               # loss for the alpha observation angle
    if 'nuscenes_att' in opt.heads:
      self.crit_nuscenes_att = WeightedBCELoss() # loss for the nuScenes object attributes
    self.opt = opt
    self.crit_dep = DepthLoss()                  # loss for the depth estimate


2.2 Initialization of model_with_loss

Next, we introduce the initialization of model_with_loss. The member is initialized with the ModelWithLoss class, defined at line 110 of CenterFusion/src/lib/trainer.py; its __init__() function simply stores its arguments, which makes the later call to its forward() function convenient.
Note: the ModelWithLoss class is instantiated at line 135 of CenterFusion/src/lib/trainer.py

self.model_with_loss = ModelWithLoss(model, self.loss, opt) 

def __init__(self, model, loss, opt):
  super(ModelWithLoss, self).__init__()
  self.opt = opt     # store the passed-in opt
  self.model = model # store the passed-in model
  self.loss = loss   # store the passed-in loss

3. Call path leading to the loss calculation

In line 84 of the main() function of CenterFusion/src/main.py, the train() function in the Trainer class is called.

log_dict_train, _ = trainer.train(epoch, train_loader)

The train() function is defined in line 405 of CenterFusion/src/lib/trainer.py

def train(self, epoch, data_loader):
  return self.run_epoch('train', epoch, data_loader)  # call run_epoch()

The run_epoch() function is defined at line 150 of CenterFusion/src/lib/trainer.py; the statement that launches the loss computation is at line 178:

# run one iteration 
output, loss, loss_stats = model_with_loss(batch, phase)

4. Loss calculation process

4.1 The call chain into the loss calculation

The entry point of the loss calculation is the model_with_loss() call in run_epoch(), at line 178 of CenterFusion/src/lib/trainer.py; the loss variable receives the result of the loss computation.

# run one iteration 
output, loss, loss_stats = model_with_loss(batch, phase)

model_with_loss is an alias of self.model_with_loss, assigned at line 151 of CenterFusion/src/lib/trainer.py:

model_with_loss = self.model_with_loss

And self.model_with_loss is initialized to the ModelWithLoss class in the __init__() function of the Trainer class, located at line 135 of CenterFusion/src/lib/trainer.py

self.model_with_loss = ModelWithLoss(model, self.loss, opt)

Therefore, the call model_with_loss(batch, phase) at line 178 of CenterFusion/src/lib/trainer.py invokes the forward() function of the ModelWithLoss class, located at line 117 of CenterFusion/src/lib/trainer.py:

def forward(self, batch, phase):
  pc_dep = batch.get('pc_dep', None)
  pc_hm = batch.get('pc_hm', None)
  calib = batch['calib'].squeeze(0)

  ## run the first stage
  outputs = self.model(batch['image'], pc_hm=pc_hm, pc_dep=pc_dep, calib=calib)
  
  loss, loss_stats = self.loss(outputs, batch) # compute the loss
  return outputs[-1], loss, loss_stats
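
For orientation, outputs is a list with one element per network stack (opt.num_stacks), each element a dict mapping head names to tensors; a hypothetical sketch of its structure (head names and the 1 x 10 x 112 x 200 heatmap shape are taken from the examples in this post, the rest is illustrative):

import torch

# Illustration only: the structure self.model returns in this code base.
outputs = [
  {
    'hm':  torch.zeros(1, 10, 112, 200),  # class heatmap logits
    'wh':  torch.zeros(1, 2, 112, 200),   # box width/height
    'reg': torch.zeros(1, 2, 112, 200),   # center offset
    # ... one entry per head in opt.heads
  },
]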

The loss is computed by the self.loss() call, where self.loss is set in the __init__() function of ModelWithLoss at line 115 of CenterFusion/src/lib/trainer.py:

self.loss = loss

The loss on the right-hand side of the assignment is the parameter passed into __init__(), at line 111 of CenterFusion/src/lib/trainer.py:

class ModelWithLoss(torch.nn.Module):
  def __init__(self, model, loss, opt):  # line 111 of CenterFusion/src/lib/trainer.py
    super(ModelWithLoss, self).__init__()
    self.opt = opt
    self.model = model
    self.loss = loss

That parameter is supplied when self.model_with_loss is initialized to a ModelWithLoss instance in the __init__() function of the Trainer class, at line 135 of CenterFusion/src/lib/trainer.py:

self.model_with_loss = ModelWithLoss(model, self.loss, opt)

And self.loss is the second return value of the _get_losses() function of the Trainer class, assigned at line 134 of CenterFusion/src/lib/trainer.py:

self.loss_stats, self.loss = self._get_losses(opt)

The construction and return of the loss happen at lines 243 and 244 of CenterFusion/src/lib/trainer.py, which instantiate the GenericLoss class:

def _get_losses(self, opt):
  loss_order = ['hm', 'wh', 'reg', 'ltrb', 'hps', 'hm_hp', \
    'hp_offset', 'dep', 'dep_sec', 'dim', 'rot', 'rot_sec',
    'amodel_offset', 'ltrb_amodal', 'tracking', 'nuscenes_att', 'velocity']
  loss_states = ['tot'] + [k for k in loss_order if k in opt.heads]
  loss = GenericLoss(opt)    # line 243 of CenterFusion/src/lib/trainer.py
  return loss_states, loss   # line 244 of CenterFusion/src/lib/trainer.py

Therefore, the self.loss() call at line 125 of CenterFusion/src/lib/trainer.py actually invokes the forward() function of the GenericLoss class:

loss, loss_stats = self.loss(outputs, batch)  # line 125 of CenterFusion/src/lib/trainer.py

4.2 The actual loss computation: the forward() function of the GenericLoss class

4.2.1 The _sigmoid_output() function

The forward() function in the GenericLoss class is located at line 46 of CenterFusion/src/lib/trainer.py

def forward(self, outputs, batch):
  opt = self.opt
  losses = {head: 0 for head in opt.heads}

  for s in range(opt.num_stacks):
    output = outputs[s]
    output = self._sigmoid_output(output) # normalize the outputs into usable ranges

_sigmoid_output() is located at line 35 of CenterFusion/src/lib/trainer.py. It maps the raw head outputs into usable ranges: the heatmaps 'hm' and 'hm_hp' are squashed into 0~1, while the depth heads are decoded with an inverse-sigmoid transform:

def _sigmoid_output(self, output):
  if 'hm' in output:
    output['hm'] = _sigmoid(output['hm'])
  if 'hm_hp' in output:
    output['hm_hp'] = _sigmoid(output['hm_hp'])
  if 'dep' in output:
    output['dep'] = 1. / (output['dep'].sigmoid() + 1e-6) - 1.
  if 'dep_sec' in output and self.opt.sigmoid_dep_sec:
    output['dep_sec'] = 1. / (output['dep_sec'].sigmoid() + 1e-6) - 1.
  return output
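
Note that the 'dep' branch is not a plain 0~1 normalization. Writing $\sigma(x)$ for the sigmoid, the code above decodes depth as

$$d = \frac{1}{\sigma(x)+\varepsilon} - 1, \qquad \varepsilon = 10^{-6},$$

which maps an unbounded network output $x$ to a non-negative depth: $x \to +\infty$ gives $d \to 0$, while $x \to -\infty$ gives arbitrarily large depth.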


The heatmap branches use the _sigmoid() function, located at line 8 of CenterFusion/src/lib/model/utils.py. It relies on torch.clamp(), which clamps the values of the input tensor x.sigmoid_() to the interval [min, max] and returns the result as a new tensor:

def _sigmoid(x):
  y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
  return y
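
A minimal sketch (values made up) of why the clamp matters: the focal loss takes log(pred) and log(1 - pred), so a raw sigmoid output of exactly 0 or 1 would produce infinities; clamping to [1e-4, 1 - 1e-4] keeps both logarithms finite.

import torch

x = torch.tensor([-100.0, 0.0, 100.0])  # extreme logits saturate the sigmoid
y = torch.clamp(x.sigmoid(), min=1e-4, max=1 - 1e-4)
print(torch.log(y), torch.log(1 - y))   # finite everywhere thanks to the clamp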

4.2.2 FastFocalLoss

Returning to the forward() function in the GenericLoss class, self.crit() is called, located at line 55 of CenterFusion/src/lib/trainer.py

    if 'hm' in output:
      losses['hm'] += self.crit(
        output['hm'], batch['hm'], batch['ind'], 
        batch['mask'], batch['cat']) / opt.num_stacks

The statement above calls self.crit(), which was initialized to a FastFocalLoss instance at line 26 of CenterFusion/src/lib/trainer.py:

self.crit = FastFocalLoss(opt=opt)

Therefore, the self.crit() call at line 55 of CenterFusion/src/lib/trainer.py actually invokes the forward() function of the FastFocalLoss class, located at line 81 of CenterFusion/src/lib/model/losses.py:

class FastFocalLoss(nn.Module):
  '''
  Reimplemented focal loss, exactly the same as the CornerNet version.
  Faster and costs much less memory.
  '''
  def __init__(self, opt=None):
    pass  # initialization shown in section 2.1

  def forward(self, out, target, ind, mask, cat):
    '''
    Arguments:
      out, target: B x C x H x W   (here 1 x 10 x 112 x 200)
      ind, mask: B x M
      cat (category id for peaks): B x M
    '''
    neg_loss = self.only_neg_loss(out, target)          # loss of the negative samples
    pos_pred_pix = _tranpose_and_gather_feat(out, ind)  # B x M x C

The helper _tranpose_and_gather_feat() (line 22 of CenterFusion/src/lib/model/utils.py) and _gather_feat() (line 16 of the same file) are:
def _tranpose_and_gather_feat(feat, ind):
  feat = feat.permute(0, 2, 3, 1).contiguous()     # first convert feat from B x C x H x W to B x H x W x C
  feat = feat.view(feat.size(0), -1, feat.size(3)) # then flatten to B x HW x C
  # The dimensions are rearranged because, when the center points were set in
  # the dataset, the image was likewise flattened into a vector and each
  # center point's position within that vector was stored.
  feat = _gather_feat(feat, ind)
  # ind holds the positions of the original 2D box centers in the flattened
  # image. Since gather() is used to pick the predicted features at those
  # positions, ind is expanded from 1 x 128 to 1 x 128 x 10: if the first
  # value of the 1 x 128 tensor is 7458, every value in the first row of the
  # 1 x 128 x 10 tensor is 7458, so the predictions of all 10 classes at that
  # position are fetched for the category selection that follows.
  # The resulting feature has shape 1 x 128 x 10.
  return feat

def _gather_feat(feat, ind):  # line 16 of CenterFusion/src/lib/model/utils.py
  dim = feat.size(2)
  ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
  feat = feat.gather(1, ind)
  return feat
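
A minimal, self-contained sketch of this gather pattern (tiny made-up shapes instead of 1 x 128 x 10):

import torch

# B=1, HW=6 flattened positions, C=3 classes
feat = torch.arange(18, dtype=torch.float32).view(1, 6, 3)
ind = torch.tensor([[4, 1]])   # flattened center positions of M=2 boxes

dim = feat.size(2)
idx = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)  # 1 x 2 x 3
picked = feat.gather(1, idx)   # 1 x 2 x 3: rows 4 and 1 of feat
print(picked)
# tensor([[[12., 13., 14.],
#          [ 3.,  4.,  5.]]])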

Back in the forward() function of FastFocalLoss:

    pos_pred = pos_pred_pix.gather(2, cat.unsqueeze(2)) # B x M
    # Second selection, by each box's category: if the first box (the one at
    # position 7458 above) belongs to the first class, only the first element
    # of that sample in feat is kept as the predicted probability.
    num_pos = mask.sum()
    pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) * \
               mask.unsqueeze(2)
    # positive-sample term of the focal loss formula
    pos_loss = pos_loss.sum()  # summed over all bboxes in the image
    if num_pos == 0:
      return - neg_loss
    return - (pos_loss + neg_loss) / num_pos  # average over the positives
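
Putting the positive and negative terms together, forward() implements the CenterNet focal loss, with $\alpha = 2$, $\beta = 4$, and $N$ (num_pos) the number of ground-truth centers:

$$L_{hm} = -\frac{1}{N}\sum_{xyc}\begin{cases}\left(1-\widehat{Y}_{xyc}\right)^{\alpha}\log\widehat{Y}_{xyc} & \text{if } Y_{xyc}=1\\\left(1-Y_{xyc}\right)^{\beta}\left(\widehat{Y}_{xyc}\right)^{\alpha}\log\left(1-\widehat{Y}_{xyc}\right) & \text{otherwise}\end{cases}$$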

4.2.3 DepthLoss

Returning to the forward() function in the GenericLoss class, self.crit_dep() is called, located at line 60 of CenterFusion/src/lib/trainer.py

    if 'dep' in output:
      losses['dep'] += self.crit_dep(
        output['dep'], batch['dep'], batch['ind'], 
        batch['dep_mask'], batch['cat']) / opt.num_stacks

The statement above calls self.crit_dep(), which was initialized to a DepthLoss instance at line 33 of CenterFusion/src/lib/trainer.py:

self.crit_dep = DepthLoss()

Therefore, the self.crit_dep() call at line 60 of CenterFusion/src/lib/trainer.py actually invokes the forward() function of the DepthLoss class, located at line 212 of CenterFusion/src/lib/model/losses.py:

class DepthLoss(nn.Module):
  def __init__(self, opt=None):
    super(DepthLoss, self).__init__()

  def forward(self, output, target, ind, mask, cat):
    '''
    Arguments:
      output, target: B x C x H x W
      ind, mask: B x M
      cat (category id for peaks): B x M
    '''
    pred = _tranpose_and_gather_feat(output, ind) # B x M x (C); the predicted depth at each box center
    if pred.shape[2] > 1:
      pred = pred.gather(2, cat.unsqueeze(2)) # B x M
    loss = F.l1_loss(pred * mask, target * mask, reduction='sum') # depth loss via the L1 loss from torch.nn.functional
    loss = loss / (mask.sum() + 1e-4)
    return loss
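
A minimal sketch (toy values) of the masked-L1 pattern used here and in RegWeightedL1Loss below: the mask zeroes out padded slots so that only annotated boxes contribute, and the sum is normalized by the number of valid entries:

import torch
import torch.nn.functional as F

pred   = torch.tensor([[[12.0], [30.0], [0.0]]])  # B x M x 1 predicted depths
target = torch.tensor([[[10.0], [28.0], [0.0]]])  # ground-truth depths
mask   = torch.tensor([[[1.0],  [1.0],  [0.0]]])  # third slot is padding

loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
loss = loss / (mask.sum() + 1e-4)
print(loss)  # (|12-10| + |30-28|) / 2 ≈ 2.0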

4.2.4 RegWeightedL1Loss

Return to the forward() function in the GenericLoss class: the regression heads are defined and self.crit_reg() is called, at line 70 of CenterFusion/src/lib/trainer.py:

      regression_heads = [
        'reg', 'wh', 'tracking', 'ltrb', 'ltrb_amodal', 'hps', 
        'dim', 'amodel_offset', 'velocity']  # the regression heads

      for head in regression_heads:          # compute the loss of the remaining regression heads
        if head in output:
          losses[head] += self.crit_reg(
            output[head], batch[head + '_mask'],
            batch['ind'], batch[head]) / opt.num_stacks

The statement above calls self.crit_reg(), which was initialized to a RegWeightedL1Loss instance at line 27 of CenterFusion/src/lib/trainer.py:

self.crit_reg = RegWeightedL1Loss()

Therefore, the self.crit_reg() call at line 70 of CenterFusion/src/lib/trainer.py actually invokes the forward() function of the RegWeightedL1Loss class, located at line 122 of CenterFusion/src/lib/model/losses.py:

class RegWeightedL1Loss(nn.Module):
  def __init__(self):
    super(RegWeightedL1Loss, self).__init__()
  
  def forward(self, output, mask, ind, target):
    pred = _tranpose_and_gather_feat(output, ind)
    loss = F.l1_loss(pred * mask, target * mask, reduction='sum') # regression loss via the L1 loss from torch.nn.functional
    loss = loss / (mask.sum() + 1e-4)
    return loss

4.2.5 Unexecuted part

Return to the forward() function in the GenericLoss class. The following two branches are skipped because the corresponding heads are absent from the output; they are located at line 75 of CenterFusion/src/lib/trainer.py:

      if 'hm_hp' in output:                              # not executed
        losses['hm_hp'] += self.crit(
          output['hm_hp'], batch['hm_hp'], batch['hp_ind'], 
          batch['hm_hp_mask'], batch['joint']) / opt.num_stacks
        if 'hp_offset' in output:                        # not executed
          losses['hp_offset'] += self.crit_reg(
            output['hp_offset'], batch['hp_offset_mask'],
            batch['hp_ind'], batch['hp_offset']) / opt.num_stacks

4.2.6 BinRotLoss

Continuing with the forward() function in the GenericLoss class, self.crit_rot() is called, located at line 84 of CenterFusion/src/lib/trainer.py

      if 'rot' in output:
        losses['rot'] += self.crit_rot(                  # regress the alpha observation angle
          output['rot'], batch['rot_mask'], batch['ind'], batch['rotbin'],
          batch['rotres']) / opt.num_stacks

The statement above calls self.crit_rot(), which was initialized to a BinRotLoss instance at line 29 of CenterFusion/src/lib/trainer.py:

if 'rot' in opt.heads:
  self.crit_rot = BinRotLoss()

Therefore, the self.crit_rot() call at line 84 of CenterFusion/src/lib/trainer.py actually invokes the forward() function of the BinRotLoss class, located at line 149 of CenterFusion/src/lib/model/losses.py:

class BinRotLoss(nn.Module):
  def __init__(self):
    super(BinRotLoss, self).__init__()
  
  def forward(self, output, mask, ind, rotbin, rotres):
    pred = _tranpose_and_gather_feat(output, ind)  # fetch the alpha-angle predictions at the corresponding positions
    loss = compute_rot_loss(pred, rotbin, rotres, mask)
    return loss

Among them, the compute_rot_loss() function is called, which is located at line 173 of CenterFusion/src/lib/model/losses.py

def compute_rot_loss(output, target_bin, target_res, mask):
    # output: (B, 128, 8) [bin1_cls[0], bin1_cls[1], bin1_sin, bin1_cos, 
    #                 bin2_cls[0], bin2_cls[1], bin2_sin, bin2_cos]
    # target_bin: (B, 128, 2) [bin1_cls, bin2_cls]
    # target_res: (B, 128, 2) [bin1_res, bin2_res]
    # mask: (B, 128, 1)
    '''
    Computes the loss of the alpha observation angle. The angle is predicted
    as a bin classification plus a residual angle within the bin: the 360
    degree range is split into two overlapping bins, and the residual of the
    angle inside each bin is regressed. If the ground-truth angle lies in the
    overlap region, the targets of both bin1 and bin2 are 1; if it lies in
    only one bin, that bin's target is 1 and the other is 0. The residual
    angle is predicted through its sin and cos.
    '''
    output = output.view(-1, 8)
    target_bin = target_bin.view(-1, 2)
    target_res = target_res.view(-1, 2)
    mask = mask.view(-1, 1)
    # The first four channels are bin1's predictions (two bin logits, then the
    # sin and cos of the residual); the last four are bin2's.
    # The bin term is a classification loss, so cross entropy is used:
    loss_bin1 = compute_bin_loss(output[:, 0:2], target_bin[:, 0], mask)        
    loss_bin2 = compute_bin_loss(output[:, 4:6], target_bin[:, 1], mask)
    
    # Residual-angle loss: if the ground truth falls into a bin, the sin and
    # cos losses for that bin are computed.
    loss_res = torch.zeros_like(loss_bin1)
    if target_bin[:, 0].nonzero().shape[0] > 0:
        idx1 = target_bin[:, 0].nonzero()[:, 0]
        valid_output1 = torch.index_select(output, 0, idx1.long())
        valid_target_res1 = torch.index_select(target_res, 0, idx1.long())
        loss_sin1 = compute_res_loss(
          valid_output1[:, 2], torch.sin(valid_target_res1[:, 0]))
        loss_cos1 = compute_res_loss(
          valid_output1[:, 3], torch.cos(valid_target_res1[:, 0]))
        loss_res += loss_sin1 + loss_cos1
    if target_bin[:, 1].nonzero().shape[0] > 0:
        idx2 = target_bin[:, 1].nonzero()[:, 0]
        valid_output2 = torch.index_select(output, 0, idx2.long())
        valid_target_res2 = torch.index_select(target_res, 0, idx2.long())
        loss_sin2 = compute_res_loss(
          valid_output2[:, 6], torch.sin(valid_target_res2[:, 1]))
        loss_cos2 = compute_res_loss(
          valid_output2[:, 7], torch.cos(valid_target_res2[:, 1]))
        loss_res += loss_sin2 + loss_cos2
    return loss_bin1 + loss_bin2 + loss_res
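
compute_bin_loss() and compute_res_loss() are not shown in this post. In the CenterNet code this implementation derives from, they are a masked cross-entropy over the bin logits and a smooth-L1 over the sin/cos residuals; a sketch of that version (check the actual definitions in CenterFusion/src/lib/model/losses.py):

import torch.nn.functional as F

def compute_res_loss(output, target):
    # smooth-L1 between the predicted and target sin/cos values
    return F.smooth_l1_loss(output, target, reduction='mean')

def compute_bin_loss(output, target, mask):
    # zero out the logits of padded boxes, then classify which bin is active
    mask = mask.expand_as(output)
    output = output * mask.float()
    return F.cross_entropy(output, target, reduction='mean')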

4.2.7 WeightedBCELoss

Returning to the forward() function in the GenericLoss class, self.crit_nuscenes_att() is called, located at line 89 of CenterFusion/src/lib/trainer.py

      if 'nuscenes_att' in output:                            # attribute loss
        losses['nuscenes_att'] += self.crit_nuscenes_att(
          output['nuscenes_att'], batch['nuscenes_att_mask'],
          batch['ind'], batch['nuscenes_att']) / opt.num_stacks

The statement above calls self.crit_nuscenes_att(), which was initialized to a WeightedBCELoss instance at line 31 of CenterFusion/src/lib/trainer.py:

if 'nuscenes_att' in opt.heads:
  self.crit_nuscenes_att = WeightedBCELoss()

Therefore, the self.crit_nuscenes_att() call at line 89 of CenterFusion/src/lib/trainer.py actually invokes the forward() function of the WeightedBCELoss class, located at line 134 of CenterFusion/src/lib/model/losses.py:

class WeightedBCELoss(nn.Module):
  def __init__(self):
    super(WeightedBCELoss, self).__init__()
    self.bceloss = torch.nn.BCEWithLogitsLoss(reduction='none')

  def forward(self, output, mask, ind, target):
    # output: B x F x H x W
    # ind: B x M
    # mask: B x M x F
    # target: B x M x F
    pred = _tranpose_and_gather_feat(output, ind) # B x M x F
    loss = mask * self.bceloss(pred, target)  # computed with torch.nn.BCEWithLogitsLoss()
    loss = loss.sum() / (mask.sum() + 1e-4)
    return loss

The loss itself is computed by torch.nn.BCEWithLogitsLoss, which combines a sigmoid with binary cross entropy in a single, numerically stable operation; reduction='none' keeps the per-element losses so that the mask can be applied before averaging.
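
A minimal sketch (toy values) of that masked BCE pattern:

import torch

bce = torch.nn.BCEWithLogitsLoss(reduction='none')

pred   = torch.tensor([[[2.0, -1.0]]])  # B x M x F logits
target = torch.tensor([[[1.0,  0.0]]])  # attribute labels
mask   = torch.tensor([[[1.0,  0.0]]])  # second attribute not annotated

loss = mask * bce(pred, target)         # per-element losses, padded ones zeroed
loss = loss.sum() / (mask.sum() + 1e-4)
print(loss)  # only the first element contributes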

4.2.8 DepthLoss Ⅱ

Returning to the forward() function in the GenericLoss class, self.crit_dep() is called, located at line 94 of CenterFusion/src/lib/trainer.py

      if 'dep_sec' in output:                                  # depth loss of the secondary regression head
        losses['dep_sec'] += self.crit_dep(
          output['dep_sec'], batch['dep'], batch['ind'], 
          batch['dep_mask'], batch['cat']) / opt.num_stacks

The self.crit_dep() function was already explained; see 4.2.3.

4.2.9 BinRotLoss Ⅱ

Returning to the forward() function in the GenericLoss class, self.crit_rot() is called, located at line 99 of CenterFusion/src/lib/trainer.py

      if 'rot_sec' in output:                                 # angle loss of the secondary regression head
        losses['rot_sec'] += self.crit_rot(
          output['rot_sec'], batch['rot_mask'], batch['ind'], batch['rotbin'],
          batch['rotres']) / opt.num_stacks

The self.crit_rot() function was already explained; see 4.2.6.

4.2.10 Calculate total loss

Return to the forward() function in the GenericLoss class one last time: it computes the weighted sum of all losses, at line 105 of CenterFusion/src/lib/trainer.py:

    losses['tot'] = 0
    for head in opt.heads:                 # total loss: weighted sum over all heads
      losses['tot'] += opt.weights[head] * losses[head] 

    return losses['tot'], losses

The weights opt.weights are defined at line 510 of CenterFusion/src/lib/opts.py:

    weight_dict = {'hm': opt.hm_weight, 'wh': opt.wh_weight,
                   'reg': opt.off_weight, 'hps': opt.hp_weight,
                   'hm_hp': opt.hm_hp_weight, 'hp_offset': opt.off_weight,
                   'dep': opt.dep_weight, 'dep_res': opt.dep_res_weight,
                   'rot': opt.rot_weight, 'dep_sec': opt.dep_weight,
                   'dim': opt.dim_weight, 'rot_sec': opt.rot_weight,
                   'amodel_offset': opt.amodel_offset_weight,
                   'ltrb': opt.ltrb_weight,
                   'tracking': opt.tracking_weight,
                   'ltrb_amodal': opt.ltrb_amodal_weight,
                   'nuscenes_att': opt.nuscenes_att_weight,
                   'velocity': opt.velocity_weight}
    opt.weights = {head: weight_dict[head] for head in opt.heads}
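
Each weight in weight_dict is read from a corresponding command-line option (presumably defined via argparse in opts.py, e.g. --hm_weight or --dep_weight), so the relative importance of each head can be tuned from the command line without touching the loss code.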

4.3 Returning the loss to the caller

The return values losses['tot'] and losses of the forward() function of the GenericLoss class become loss and loss_stats in the forward() function of the ModelWithLoss class, at line 125 of CenterFusion/src/lib/trainer.py:

loss, loss_stats = self.loss(outputs, batch)  # line 125 of CenterFusion/src/lib/trainer.py

That is, loss in ModelWithLoss.forward() equals losses['tot'] returned by GenericLoss.forward(), and loss_stats equals the full losses dict.
The loss and loss_stats returned by ModelWithLoss.forward() in turn become the loss and loss_stats produced by the model_with_loss() call in run_epoch(), at line 178 of CenterFusion/src/lib/trainer.py:

# run one iteration 
output, loss, loss_stats = model_with_loss(batch, phase) # line 178 of CenterFusion/src/lib/trainer.py

After that, the loss values are averaged into a single scalar used for this iteration's backward pass. Line 181 of CenterFusion/src/lib/trainer.py calls the mean() method of the loss tensor (relevant when the loss comes back with one value per GPU, e.g. under DataParallel):

# backpropagate and step optimizer
loss = loss.mean()     # average the loss values; line 181 of CenterFusion/src/lib/trainer.py

At this point the loss calculation is complete; backpropagation follows.
