【目标检测】CenterNet2代码解读

CenterNet2代码解读

选择配置为CenterNet2_R50_1x.yaml

先解读测试过程,再分析训练。整体代码结构如下:


1.利用Resnet50生成五层特征图
features = self.backbone(images.tensor)      # 五层分别为(8, 16, 32, 64, 128)倍下采样
 # 以输入(1, 3, 768, 1344)为例,第一层为(1, 256, 96, 168),……,最后一层为(1, 256, 6, 11)

2.生成proposal
proposals, _ = self.proposal_generator(images, features, None)

3.roi_heads得到results
results, _ = self.roi_heads(images, features, proposals, None)


所有batch中目标所在点index

一、生成proposal

该部分对应 self.proposal_generator(CenterNet 检测头)在测试阶段的流程,主要步骤如下:

1.进行第一次分类与回归(通道数分别为 256→4 与 256→1),得到 reg 与 agn_hm 两组特征图
----------------------------------------------------------------------------
clss_per_level, reg_pred_per_level, agn_hm_pred_per_level = self.centernet_head(features)
----------------------------------------------------------------------------
 ## reg_pred_per_level:(1, 4, 96, 168)……(1, 4, 6, 11)
 ## clss_per_level在测试阶段为None,agn_hm_pred_per_level:(1, 1, 96, 168)……(1, 1, 6, 11)


2.每层特征图上,得到绝对坐标值
----------------------------------------------------------------------------
grids = self.compute_grids(features)  #(16128, 2)(4032, 2)...(66, 2)
----------------------------------------------------------------------------
 ##回归得到的4参数只是一个偏移值,需得到特征图所有点的绝对坐标,计算得到box:
  h, w = feature.size()[-2:]   # 96, 168
  shifts_x = torch.arange(0, w * self.strides[level], step=self.strides[level],dtype=torch.float32,  device=feature.device)
  # 0, 8, 16, 24, 32, ……, 1336
  shifts_y = torch.arange( 0, h * self.strides[level], step=self.strides[level],dtype=torch.float32, device=feature.device)
  # 0, 8, 16, 24, 32, ……, 760
  shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
  shift_x = shift_x.reshape(-1)
  shift_y = shift_y.reshape(-1)
  grids_per_level = torch.stack((shift_x, shift_y), dim=1) + self.strides[level] // 2
  # 第一层为 (96*168, 2) 即 (16128, 2),x 坐标大致为 4, 12, 20, ……, 1340,即每个网格的中心点


3.每层特征图的尺度大小
----------------------------------------------------------------------------
shapes_per_level = grids[0].new_tensor([(x.shape[2], x.shape[3]) for x in reg_pred_per_level]) 
---------------------------------------------------------------------------- 
 # (96,168)(48., 84)..(6,11)


4.根据阈值,筛选前1000个 proposals
self.inference(images, clss_per_level, reg_pred_per_level, agn_hm_pred_per_level, grids)
 即 proposals = self.predict_instances(grids, agn_hm_pred_per_level, reg_pred_per_level,            images.image_sizes, [None for _ in agn_hm_pred_per_level])
 ---------------------------------------------------------------------------- 
 即 self.predict_single_level(grids[l], logits_pred[l], reg_pred[l] * self.strides[l],                image_sizes, agn_hm_pred[l], l, is_proposal=is_proposal))
 ---------------------------------------------------------------------------- 
 # 将每层的logits_pred作为热图(heatmap),取阈值0.001并选前1000个目标,坐标与grid进行相加,得到每层的boxlist:
   boxlist.scores = torch.sqrt(per_box_cls)   # (1000)
   boxlist.pred_boxes = Boxes(detections)     # (1000,4)
   boxlist.pred_classes = per_class           # (1000),全为 0(与类别无关的 proposal)


5.5层结果做NMS
---------------------------------------------------------------------------- 
boxlists = self.nms_and_topK(boxlists)
---------------------------------------------------------------------------- 

二、RoI_head得到result

整体代码如下,共经历三次级联网络

for k in range(self.num_cascade_stages):
    if k > 0:
       proposals = self._create_proposals_from_boxes(prev_pred_boxes, image_sizes)
       if self.training:
           proposals = self._match_and_label_boxes(proposals, k, targets)
    predictions = self._run_stage(features, proposals, k)    # tuple:(256,81)(256,4),4为xywh
    prev_pred_boxes = self.box_predictor[k].predict_boxes(predictions, proposals)
    head_outputs.append((self.box_predictor[k], predictions, proposals))

循环3次,每次将feature与proposal生成新的proposal,保存在 head_outputs 中。
下面分别展开各个函数:

1.self._run_stage

主要是 RoI 池化(本配置中实际使用的是 ROIAlign,见下文第四节),以及分类和回归

box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])  # ([256, 256, 7, 7])
box_features = self.box_head[stage](box_features)                                  # ([256, 1024])
return self.box_predictor[stage](box_features)                  # 全连接层:(1024→81)与(1024→4)

2.self.box_predictor[k].predict_boxes

每个级联 stage 各有一个 box_predictor(共 3 个),每个都包含两个全连接层:Linear(1024, 81, bias=True)(分类)与 Linear(1024, 4, bias=True)(回归)


_, proposal_deltas = predictions          # ( 256,4 ) 
num_prop_per_image = [len(p) for p in proposals]    # [256]
proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)   # ([256, 4])
predict_boxes = self.box2box_transform.apply_deltas(proposal_deltas, proposal_boxes)  #再次解码

利用Roi后的回归值,再次解码proposal,过程如下

def apply_deltas(self, deltas, boxes):
    """Decode box-regression deltas relative to reference boxes.

    Args:
        deltas: tensor of shape (R, k*4). Each group of four columns is one
            (dx, dy, dw, dh) prediction, stored multiplied by ``self.weights``
            (the weights are divided back out here).
        boxes: tensor of shape (R, 4) holding reference boxes as
            (x1, y1, x2, y2).

    Returns:
        Tensor with the same shape as ``deltas``; each group of four columns
        is a decoded (x1, y1, x2, y2) box.
    """
    # Decode in fp32 for numerical precision.
    deltas = deltas.float()
    boxes = boxes.to(deltas.dtype)

    # Reference geometry as (R, 1) columns so it broadcasts over the k groups.
    ref_w = (boxes[:, 2] - boxes[:, 0])[:, None]
    ref_h = (boxes[:, 3] - boxes[:, 1])[:, None]
    ref_cx = boxes[:, 0][:, None] + 0.5 * ref_w
    ref_cy = boxes[:, 1][:, None] + 0.5 * ref_h

    scale_x, scale_y, scale_w, scale_h = self.weights  # e.g. (10.0, 10.0, 5.0, 5.0)
    off_x = deltas[:, 0::4] / scale_x
    off_y = deltas[:, 1::4] / scale_y
    # Cap the log-scale terms so torch.exp() cannot blow up.
    log_w = torch.clamp(deltas[:, 2::4] / scale_w, max=self.scale_clamp)
    log_h = torch.clamp(deltas[:, 3::4] / scale_h, max=self.scale_clamp)

    center_x = off_x * ref_w + ref_cx
    center_y = off_y * ref_h + ref_cy
    half_w = 0.5 * (torch.exp(log_w) * ref_w)
    half_h = 0.5 * (torch.exp(log_h) * ref_h)

    decoded = torch.stack(
        (center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h),
        dim=-1,
    )
    return decoded.reshape(deltas.shape)

3.self._create_proposals_from_boxes

在测试阶段这里没什么意义,输入等于输出。

(原文此处仅为编辑器默认的占位代码块 “var foo = 'bar';”,并未给出该函数的实际代码。)

三、联合概率预测与后处理

1. 三次级联得分求平均([256, 81])
  scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs] # 将 prediction 中的分类得分归一化为概率(softmax;若启用 sigmoid CE 则为 sigmoid),得到 3 个 (256, 81)
  scores = [sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
          for scores_per_image in zip(*scores_per_stage)]   
          
2. 与首次分类得分相乘            
  scores = [(s * ps[:, None]) ** 0.5  for s, ps in zip(scores, proposal_scores)] # (256, 81) 与 (256, 1) 相乘后开方,得到 (256, 81)

3. 利用最后一次级联结果作解码,得到最终box,再后处理
  predictor, predictions, proposals = head_outputs[-1]
  boxes = predictor.predict_boxes(predictions, proposals)               # ([256, 4])
  pred_instances, _ = fast_rcnn_inference(boxes,scores, image_sizes,
                predictor.test_score_thresh,
                predictor.test_nms_thresh,
                predictor.test_topk_per_image,)   # 0.3  0.7  100

四、额外函数

1.torch.kthvalue(筛选前n个,确定阈值)

cls_scores = result.scores
image_thresh, _ = torch.kthvalue(cls_scores.cpu(),num_dets - post_nms_topk + 1)
# 例如 cls_scores中的 num_dets=2492,只需要前 post_nms_topk=256 个得分,可计算出阈值image_thresh
keep = cls_scores >= image_thresh.item()
keep = torch.nonzero(keep).squeeze(1)
result = result[keep]

2.NMS

from torchvision.ops import boxes as box_ops
keep = box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)
# keep(2492):tensor([2645,  249, 1724,  ..., 2081, 2999, 3062], device='cuda:0')
# boxes 为张量 (3318, 4),scores 为分数 (3318),idxs 为类别 (3318,全为 0),iou_threshold 为 0.9
boxlist = boxlist[keep]

3.RoIpool

主要作用:输入为(n,5)的ROI ,即感兴趣区域。根据大小,将其分配到三种尺度特征图上(5种也行),然后从原来的特征金字塔上抠出对应特征图。

from torchvision.ops import RoIPool
self.level_poolers = nn.ModuleList(RoIPool(output_size, spatial_scale=scale) for scale in scales)

level_assignments = assign_boxes_to_levels( box_lists, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level)                                
# (256): [0, 2, 0, 0, 1, 0, 1, 0, 2, 1, 0, 0, 0, 2...]

for level, pooler in enumerate(self.level_poolers):
    inds = nonzero_tuple(level_assignments == level)[0]         # (179)个序列
    pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
    # Use index_put_ instead of advance indexing, to avoid pytorch/issues/49852
    output.index_put_((inds,), pooler(x[level], pooler_fmt_boxes_level))

# 其中,level_poolers为:
self.level_poolers = nn.ModuleList(  ROIAlign(  output_size, spatial_scale=scale, sampling_ratio=0, aligned=True )    for scale in scales  )
## scale 为 1/8 ~ 1/128(各层特征图相对输入图像的缩放比例)

3.1 ROIAlign

from torchvision.ops import roi_align

class ROIAlign(nn.Module):
    """Thin ``nn.Module`` wrapper around ``torchvision.ops.roi_align``.

    Args:
        output_size: size of the pooled output per RoI, forwarded to roi_align.
        spatial_scale: factor mapping box coordinates onto this feature map's
            resolution (e.g. 1/8 for a feature map downsampled 8x).
        sampling_ratio: sampling points per output bin, forwarded to roi_align
            (see the torchvision docs; 0 selects an adaptive count).
        aligned: forwarded to roi_align's ``aligned`` flag.
    """

    def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
        super().__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

        # torchvision < 0.7 is rejected; see
        # https://github.com/pytorch/vision/pull/2438 (presumably where the
        # `aligned` option forwarded below was introduced).
        from torchvision import __version__ as tv_version

        major_minor = tuple(map(int, tv_version.split(".")[:2]))
        assert major_minor >= (0, 7), "Require torchvision >= 0.7"

    def forward(self, input, rois):
        """Pool a fixed-size feature for every RoI.

        Args:
            input: NCHW feature map.
            rois: (K, 5) tensor. Column 0 is the index into the batch
                dimension N of ``input``; columns 1-4 are (x1, y1, x2, y2).

        Returns:
            The result of ``roi_align`` for these RoIs.
        """
        assert rois.dim() == 2 and rois.size(1) == 5
        # Cast the boxes to the feature map's dtype before pooling.
        boxes = rois.to(dtype=input.dtype)
        return roi_align(
            input, boxes, self.output_size, self.spatial_scale,
            self.sampling_ratio, self.aligned,
        )

3.2 assign_boxes_to_levels,把box分配到不同尺度特征图

比如:输入:(256,4)。输出:(256),即 [0,0,0,1,1,0,0,0,2,2,…]

def assign_boxes_to_levels(
    box_lists: List[Boxes],
    min_level: int,
    max_level: int,
    canonical_box_size: int,
    canonical_level: int,
):
    """Map every box to the index of the pyramid level it is pooled from.

    Uses the FPN-style rule
        k = floor(canonical_level + log2(sqrt(area) / canonical_box_size)),
    clamps k to [min_level, max_level], and shifts it so min_level maps to 0.

    Args:
        box_lists: per-image Boxes; their boxes are concatenated in order.
        min_level: finest pyramid level available.
        max_level: coarsest pyramid level available.
        canonical_box_size: side length (sqrt of area) that maps to
            canonical_level.
        canonical_level: level assigned to a box of canonical size.

    Returns:
        int64 tensor with one entry per box, in [0, max_level - min_level].
    """
    areas = cat([b.area() for b in box_lists])
    side = torch.sqrt(areas)  # geometric-mean side length of each box
    # The 1e-8 keeps log2 finite for zero-area boxes.
    raw_level = canonical_level + torch.log2(side / canonical_box_size + 1e-8)
    # Sizes outside the available range are pinned to the nearest level.
    level = torch.clamp(torch.floor(raw_level), min=min_level, max=max_level)
    return level.to(torch.int64) - min_level

五.训练

训练损失共2部分

proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
_, detector_losses = self.roi_heads(images, features, proposals, gt_instances)

1.proposal_losses


1._get_ground_truth

把标签映射成feature map的维度

    def _get_ground_truth(self, grids, shapes_per_level, gt_instances):
        '''
        Build CenterNet training targets (positive indices, regression targets
        and heatmaps) for every FPN pixel of every image in the batch.

        Input:
            grids: list of tensors [(hl x wl, 2)]_l; (x, y) center of every
                pixel of level l in input-image coordinates
            shapes_per_level: L x 2 tensor of (h_l, w_l)
            gt_instances: per-image gt instances (gt_boxes, gt_classes)
        Return:
            pos_inds: N flat indices of the pixels holding GT centers
                (None when self.more_pos is set)
            labels: N class ids matching pos_inds (None when self.more_pos)
            reg_targets: MB x 4 for B images, ordered level-first; (l, t, r, b)
                distances divided by the level stride, -INF where a pixel has
                no regression target
            flattened_hms: MB x C, or MB x 1 when self.only_proposal
            N: number of objects in all images
            M: number of pixels from all FPN levels
        '''

        # get positive pixel index
        if not self.more_pos:
            pos_inds, labels = self._get_label_inds(
                gt_instances, shapes_per_level)  # N, N: flat pixel index of each GT center over the whole batch, and its class
        else:
            pos_inds, labels = None, None
        heatmap_channels = self.num_classes
        L = len(grids)
        num_loc_list = [len(loc) for loc in grids]
        # M: stride of the level each pixel belongs to
        # (e.g. 14720 x [8] + ... + 60 x [128] in the author's run).
        strides = torch.cat([
            shapes_per_level.new_ones(num_loc_list[l]) * self.strides[l] \
            for l in range(L)]).float()
        # M x 2: [min, max] object-size range handled by each pixel's level.
        reg_size_ranges = torch.cat([
            shapes_per_level.new_tensor(self.sizes_of_interest[l]).float().view(
            1, 2).expand(num_loc_list[l], 2) for l in range(L)])
        grids = torch.cat(grids, dim=0)  # M x 2: pixel centers of all levels, level by level
        M = grids.shape[0]

        reg_targets = []
        flattened_hms = []
        for i in range(len(gt_instances)):  # one iteration per image
            boxes = gt_instances[i].gt_boxes.tensor # N x 4
            area = gt_instances[i].gt_boxes.area() # N
            gt_classes = gt_instances[i].gt_classes # N in [0, self.num_classes]

            N = boxes.shape[0]
            if N == 0:
                # No GT in this image: every pixel gets a -INF regression
                # target and an all-zero heatmap.
                reg_targets.append(grids.new_zeros((M, 4)) - INF)
                flattened_hms.append(
                    grids.new_zeros((
                        M, 1 if self.only_proposal else heatmap_channels)))
                continue

            # Signed distances from every pixel center to the four sides of
            # every box; all four are positive only inside the box.
            l = grids[:, 0].view(M, 1) - boxes[:, 0].view(1, N)  # M x N
            t = grids[:, 1].view(M, 1) - boxes[:, 1].view(1, N)  # M x N
            r = boxes[:, 2].view(1, N) - grids[:, 0].view(M, 1)  # M x N
            b = boxes[:, 3].view(1, N) - grids[:, 1].view(M, 1)  # M x N
            reg_target = torch.stack([l, t, r, b], dim=2)  # M x N x 4

            centers = ((boxes[:, [0, 1]] + boxes[:, [2, 3]]) / 2)  # N x 2
            centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2
            strides_expanded = strides.view(M, 1, 1).expand(M, N, 2)
            # M x N x 2: each box center snapped to the center of the grid
            # cell that contains it, at the stride of each pixel's level.
            centers_discret = ((centers_expanded / strides_expanded).int() * \
                strides_expanded).float() + strides_expanded / 2

            # M x N: pixel center coincides with the snapped box center.
            is_peak = (((grids.view(M, 1, 2).expand(M, N, 2) - \
                centers_discret) ** 2).sum(dim=2) == 0)
            is_in_boxes = reg_target.min(dim=2)[0] > 0  # M x N: pixel strictly inside the box
            # get_center3x3 takes (M x 2, N x 2, M) and returns M x N; judging by
            # its name it marks the 3x3 cells around each center (TODO confirm).
            is_center3x3 = self.get_center3x3(
                grids, centers, strides) & is_in_boxes
            # M x N: per the author's notes, compares the box size derived from
            # (l, t, r, b) with the size range of the pixel's level.
            is_cared_in_the_level = self.assign_reg_fpn(
                reg_target, reg_size_ranges)
            reg_mask = is_center3x3 & is_cared_in_the_level  # M x N: pairs allowed to regress

            # M x N: squared distance from each pixel center to each true box center.
            dist2 = ((grids.view(M, 1, 2).expand(M, N, 2) - \
                centers_expanded) ** 2).sum(dim=2) # M x N
            dist2[is_peak] = 0  # peak pixels count as exactly on the center
            # N: per-box normalizer delta^2 * 2 * area, at least min_radius^2.
            radius2 = self.delta ** 2 * 2 * area # N
            radius2 = torch.clamp(
                radius2, min=self.min_radius ** 2)
            weighted_dist2 = dist2 / radius2.view(1, N).expand(M, N) # M x N
            # M x 4: target of the nearest allowed box; -INF where none applies.
            reg_target = self._get_reg_targets(
                reg_target, weighted_dist2.clone(), reg_mask, area) # M x 4

            if self.only_proposal:
                # M x 1: exp(-min over boxes of weighted_dist2), driven by the
                # nearest GT center.
                flattened_hm = self._create_agn_heatmaps_from_dist(
                    weighted_dist2.clone())
            else:
                # Per-class heatmaps; per the author, this branch is not taken
                # in the CenterNet2_R50_1x config.
                flattened_hm = self._create_heatmaps_from_dist(
                    weighted_dist2.clone(), gt_classes, 
                    channels=heatmap_channels) # M x C

            reg_targets.append(reg_target)  # (M, 4)
            flattened_hms.append(flattened_hm)  # (M, 1) or (M, C)

        # transpose im first training_targets to level first ones
        reg_targets = _transpose(reg_targets, num_loc_list)  # one tensor per level, e.g. [64512, 4], ..., [66, 4]
        flattened_hms = _transpose(flattened_hms, num_loc_list)  # e.g. [64512, 1], ..., [66, 1]
        # Express regression targets in units of each level's stride.
        for l in range(len(reg_targets)):
            reg_targets[l] = reg_targets[l] / float(self.strides[l])
        reg_targets = cat([x for x in reg_targets], dim=0)  # MB x 4, e.g. (85944, 4)
        flattened_hms = cat([x for x in flattened_hms], dim=0)  # MB x C, e.g. (85944, 1)

        return pos_inds, labels, reg_targets, flattened_hms

1.get_label_inds(centernet.py):分配标签到特征图上

def _get_label_inds(self, gt_instances, shapes_per_level):
        '''
        Inputs:
            gt_instances: [n_i], sum n_i = N
            shapes_per_level: L x 2 [(h_l, w_l)]_L
        Returns:
            pos_inds: N'
            labels: N'
        '''
        pos_inds = []
        labels = []
        L = len(self.strides)         # 5
        B = len(gt_instances)    # bs
        shapes_per_level = shapes_per_level.long()
        loc_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1]).long()         # [16128,  4032,  1008,   252,    66]
        level_bases = []
        s = 0
        for l in range(L):
            level_bases.append(s)
            s = s + B * loc_per_level[l]                                                                                               # [0, 64512, 80640, 84672, 85680 ]
        level_bases = shapes_per_level.new_tensor(level_bases).long()                    # [0, 64512, 80640, 84672, 85680 ]
        strides_default = shapes_per_level.new_tensor(self.strides).float()               #  [ 8, 16, 32, 64, 128 ]
        for im_i in range(B):
            targets_per_im = gt_instances[im_i]
            bboxes = targets_per_im.gt_boxes.tensor                                      # n x 4: (x1, y1, x2, y2)
            n = bboxes.shape[0]
            centers = ((bboxes[:, [0, 1]] + bboxes[:, [2, 3]]) / 2)                       # n x 2
            centers = centers.view(n, 1, 2).expand(n, L, 2)                              # ( n, 5, 2 )
            strides = strides_default.view(1, L, 1).expand(n, L, 2)                # [ 8.,  16.,  32.,  64., 128 ] -->  ( n, 5, 2 )
            centers_inds = (centers / strides).long()                                           # n x 5 x 2
            Ws = shapes_per_level[:, 1].view(1, L).expand(n, L)                    # ( n, 5 )5层特征图的宽,单独拿出来
            pos_ind = level_bases.view(1, L).expand(n, L) + \
                       im_i * loc_per_level.view(1, L).expand(n, L) + \
                       centers_inds[:, :, 1] * Ws + \
                       centers_inds[:, :, 0]                                                                         # n x 5  : 把B个图片的5层特征图拉成直线,找到n个标签中心所在的索引
            is_cared_in_the_level = self.assign_fpn_level(bboxes)            # box 为绝对值  --> (n, 5):[True  False... ] 根据标签面积大小,确定该目标在哪层特征图
            pos_ind = pos_ind[is_cared_in_the_level].view(-1)                   # (n)
            label = targets_per_im.gt_classes.view(
                n, 1).expand(n, L)[is_cared_in_the_level].view(-1)                 # (n) class 绝对值

            pos_inds.append(pos_ind) # n'
            labels.append(label) # n'
        pos_inds = torch.cat(pos_inds, dim=0).long()                                  # 一个batch里,所有目标中心点所在索引 [516, 3692,7533,...55,209...,71433]
        labels = torch.cat(labels, dim=0)
        return pos_inds, labels                                                                               # N, N

2.assign_fpn_level:根据标签面积大小,确定该目标在哪层特征图

    def assign_fpn_level(self, boxes):
        '''
        Decide, by size, which FPN levels each box belongs to.

        A box's size criterion is half its diagonal, sqrt(w^2 + h^2) / 2
        (not its area). The box is assigned to every level l whose range
        [lo_l, hi_l] in self.sizes_of_interest contains that value (both bounds
        inclusive), so a box can land on several levels when ranges overlap.

        Inputs:
            boxes: n x 4, (x1, y1, x2, y2)
        Return:
            is_cared_in_the_level: n x L bool
        '''
        num_levels = len(self.sizes_of_interest)
        ranges = boxes.new_tensor(self.sizes_of_interest).view(num_levels, 2)  # L x 2
        extent = boxes[:, 2:] - boxes[:, :2]  # n x 2: (w, h)
        half_diag = (extent ** 2).sum(dim=1) ** 0.5 / 2  # n
        n = half_diag.shape[0]
        crit = half_diag.view(n, 1).expand(n, num_levels)
        lower = ranges[:, 0].view(1, num_levels).expand(n, num_levels)
        upper = ranges[:, 1].view(1, num_levels).expand(n, num_levels)
        return (crit >= lower) & (crit <= upper)

3.get_reg_targets:需要回归的gt(M,4)


    def _get_reg_targets(self, reg_targets, dist, mask, area):
        '''
        Give every pixel the regression target of its closest eligible GT box.

          reg_targets (M x N x 4): (l, t, r, b) offsets from each pixel to each box
          dist (M x N): normalized pixel-to-center distances; modified in place,
              so callers should pass a copy
          mask (M x N): which (pixel, box) pairs may be used for regression
          area: not used here
        Returns:
          M x 4 targets; rows with no eligible box are filled with -INF.
        '''
        # Ineligible pairs get an INF distance so they never win the argmin.
        dist[mask == 0] = INF * 1.0
        closest_dist, closest_box = dist.min(dim=1)  # M, M
        every_pixel = range(len(reg_targets))
        picked = reg_targets[every_pixel, closest_box]  # M x N x 4 -> M x 4
        no_match = closest_dist == INF
        picked[no_match] = - INF
        return picked

4.create_agn_heatmaps_from_dist

    def _create_agn_heatmaps_from_dist(self, dist):
        '''
        TODO (Xingyi): merge it with _create_heatmaps_from_dist
        dist: M x N
        return:
          heatmaps: M x 1
        '''
        heatmaps = dist.new_zeros((dist.shape[0], 1))             # (M, 1)
        heatmaps[:, 0] = torch.exp(-dist.min(dim=1)[0])
        zeros = heatmaps < 1e-4
        heatmaps[zeros] = 0
        return heatmaps

2.计算损失

    def losses(
        self, pos_inds, labels, reg_targets, flattened_hms,
        logits_pred, reg_pred, agn_hm_pred):
        '''
        Compute the CenterNet proposal-stage losses.

        Returned dict keys:
            loss_centernet_pos / loss_centernet_neg: per-class heatmap focal
                loss (only when not self.only_proposal)
            loss_centernet_loc: self.iou_loss on the pixels whose regression
                target is non-negative
            loss_centernet_agn_pos / loss_centernet_agn_neg: class-agnostic
                heatmap focal loss (only when self.with_agn_hm)
        The focal terms are divided by the number of positives averaged over
        get_world_size() processes (at least 1). The regression term is
        divided by the summed regression weight averaged the same way
        (at least 1).

        Inputs:
            pos_inds: N
            labels: N
            reg_targets: M x 4
            flattened_hms: M x C
            logits_pred: M x C
            reg_pred: M x 4
            agn_hm_pred: M x 1 or None
            N: number of positive locations in all images
            M: number of pixels from all FPN levels
            C: number of classes
        '''
        assert (torch.isfinite(reg_pred).all().item())
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        # reduce_sum presumably sums the count over all processes, so every
        # process normalizes by the same value.
        total_num_pos = reduce_sum(
            pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        losses = {
    
    }
        if not self.only_proposal:
            pos_loss, neg_loss = heatmap_focal_loss_jit(
                logits_pred, flattened_hms, pos_inds, labels,
                alpha=self.hm_focal_alpha, 
                beta=self.hm_focal_beta, 
                gamma=self.loss_gamma, 
                reduction='sum',
                sigmoid_clamp=self.sigmoid_clamp,
                ignore_high_fp=self.ignore_high_fp,
            )
            pos_loss = self.pos_weight * pos_loss / num_pos_avg
            neg_loss = self.neg_weight * neg_loss / num_pos_avg
            losses['loss_centernet_pos'] = pos_loss
            losses['loss_centernet_neg'] = neg_loss

        # Pixels with a valid regression target (-INF marks "no target");
        # e.g. 832 such pixels for about 200 GT boxes in the author's run.
        reg_inds = torch.nonzero(reg_targets.max(dim=1)[0] >= 0).squeeze(1)
        reg_pred = reg_pred[reg_inds]
        reg_targets_pos = reg_targets[reg_inds]  # (num_reg, 4)
        # Heatmap value of each pixel (max over channels), used as its
        # regression weight; M entries before selection.
        reg_weight_map = flattened_hms.max(dim=1)[0]
        reg_weight_map = reg_weight_map[reg_inds]  # (num_reg,)
        # With not_norm_reg, every selected pixel gets weight 1 instead.
        reg_weight_map = reg_weight_map * 0 + 1 \
            if self.not_norm_reg else reg_weight_map
        reg_norm = max(reduce_sum(reg_weight_map.sum()).item() / num_gpus, 1)
        reg_loss = self.reg_weight * self.iou_loss(
            reg_pred, reg_targets_pos, reg_weight_map,
            reduction='sum') / reg_norm
        losses['loss_centernet_loc'] = reg_loss

        if self.with_agn_hm:  # True in the author's config
            # Class-agnostic target: max over heatmap channels per pixel.
            cat_agn_heatmap = flattened_hms.max(dim=1)[0] # M
            agn_pos_loss, agn_neg_loss = binary_heatmap_focal_loss_jit(
                agn_hm_pred, cat_agn_heatmap, pos_inds,
                alpha=self.hm_focal_alpha, 
                beta=self.hm_focal_beta, 
                gamma=self.loss_gamma,
                sigmoid_clamp=self.sigmoid_clamp,
                ignore_high_fp=self.ignore_high_fp,
            )
            agn_pos_loss = self.pos_weight * agn_pos_loss / num_pos_avg
            agn_neg_loss = self.neg_weight * agn_neg_loss / num_pos_avg
            losses['loss_centernet_agn_pos'] = agn_pos_loss
            losses['loss_centernet_agn_neg'] = agn_neg_loss

        if self.debug:
            print('losses', losses)
            print('total_num_pos', total_num_pos)
        return losses

2.detector_losses





猜你喜欢

转载自blog.csdn.net/qq_45752541/article/details/121006775
今日推荐