CenterPoint: Paper and Code Walkthrough

Contents

1. Introduction

2. Paper Structure

3. Code


Paper: https://arxiv.org/pdf/2006.11275.pdf

Code: tianweiy/CenterPoint (github.com)

1. Introduction

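CenterPoint is an anchor-free method. It predicts each object's center point directly and then regresses its size (w, h, l), which removes the anchor-to-GT matching step (conventional anchor-based methods assign anchors to GT boxes by computing their IoU). Point-based predictions also make downstream tasks such as tracking easier. The experiments at the end of the paper show that the method learns object orientation better than anchor-based baselines. Presumably this is because each object starts from a single point with no orientation prior, which forces the model to learn more about rotation. Anchor-based methods, by contrast, start from the anchor prior, so their models converge more easily.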

2. Paper Structure

 

The overall architecture is very similar to PointPillars. The main change is that the head is anchor-free, so the head is what we analyze here.

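In the front part of the network, the point cloud is processed by the VFE (voxel feature encoder) and scattered onto the BEV plane, and an FPN-style neck produces a [B,C,H,W] feature map. A conv layer then adjusts the number of channels, and the result is fed to five heads (each is just a stack of convs followed by one conv that reduces the channels to the required dimension), producing reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W] and hm [B,8,H,W]. The predicted reg is the sub-pixel offset of the center within its cell; it compensates for the quantization error introduced when the continuous center is snapped to an integer cell of the downsampled feature map.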
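
To make the head shapes concrete, here is a minimal NumPy sketch. It is a simplification, not the post's SepHead: each head is reduced to a single randomly initialised 1x1 convolution (a per-pixel linear map), whereas the real heads use 3x3 conv + BN + ReLU blocks followed by an output conv. The names HEADS, conv1x1 and run_heads and the small 20x38 map are hypothetical; the per-head channel counts come from the shapes listed above.

```python
import numpy as np

# Hypothetical head layout: output channels per head, matching the shapes above.
HEADS = {"reg": 2, "height": 1, "dim": 3, "rot": 2, "hm": 8}


def conv1x1(x, weight, bias):
    """Per-pixel linear map [B, Cin, H, W] -> [B, Cout, H, W] (a 1x1 convolution)."""
    # weight: [Cout, Cin], bias: [Cout]
    return np.einsum("oc,bchw->bohw", weight, x) + bias[None, :, None, None]


def run_heads(features, heads=HEADS, seed=0):
    """Apply one randomly initialised 1x1 conv per head to the shared feature map."""
    rng = np.random.default_rng(seed)
    in_ch = features.shape[1]
    outputs = {}
    for name, out_ch in heads.items():
        weight = rng.normal(scale=0.01, size=(out_ch, in_ch))
        bias = np.zeros(out_ch)
        outputs[name] = conv1x1(features, weight, bias)
    return outputs


# Shared features after the channel-reducing conv: [B, 64, H, W] (small H, W for the demo)
feats = np.zeros((2, 64, 20, 38), dtype=np.float32)
preds = run_heads(feats)
for name, out in preds.items():
    print(name, out.shape)  # first line: reg (2, 2, 20, 38)
```

The four regression heads add up to 2 + 1 + 3 + 2 = 8 channels, which is why the regression target anno_box has 8 columns.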

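Inference: take the exponential of dim, recover the angle from the predicted sine and cosine in rot, and add reg to the grid coordinates generated by meshgrid to get absolute coordinates on the feature map (these are then scaled by out_size_factor and voxel_size and shifted by the point-cloud range to get LiDAR coordinates). The results are concatenated into boxes of shape [B,H*W,7], and hm is passed through a sigmoid; both are sent to post-processing. Post-processing first takes the max of the heatmap over the channel dimension to get each location's score and label, builds a mask from the per-class score thresholds to decide which locations to keep, and then runs NMS to remove redundant boxes. Only the one-stage pipeline is covered here; the paper uses two stages, with an additional box-refinement stage. Note: this CenterPoint implementation does use NMS.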
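
The decoding and score-filtering steps can be sketched in NumPy as follows. This is an illustration under assumptions, not the post's predict/post_processing: the inputs are taken to be channels-last ([B,H,W,C], i.e. after the permute(0, 2, 3, 1) in predict), and the distance mask and the rotated NMS (iou3d_nms_cuda) are left out. The function names decode_boxes and score_filter and all toy values are hypothetical.

```python
import numpy as np


def decode_boxes(reg, height, dim, rot, hm_logits, voxel_size, pc_range, out_size_factor):
    """Decode channels-last head outputs ([B, H, W, C]) into dense boxes and class scores.

    Returns boxes [B, H*W, 7] as (x, y, z, l, w, h, yaw) and scores [B, H*W, num_cls].
    """
    B, H, W, _ = reg.shape
    # Integer cell indices: ys runs along H, xs along W.
    ys, xs = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
    offset = reg.reshape(B, H * W, 2)
    xs = xs.reshape(1, H * W, 1) + offset[..., 0:1]   # add the sub-cell offset
    ys = ys.reshape(1, H * W, 1) + offset[..., 1:2]
    # Feature-map cells -> metric LiDAR coordinates.
    xs = xs * out_size_factor * voxel_size[0] + pc_range[0]
    ys = ys * out_size_factor * voxel_size[1] + pc_range[1]
    z = height.reshape(B, H * W, 1)
    lwh = np.exp(dim.reshape(B, H * W, 3))            # dims were regressed in log space
    yaw = np.arctan2(rot[..., 0:1], rot[..., 1:2]).reshape(B, H * W, 1)  # (sin, cos) -> angle
    boxes = np.concatenate([xs, ys, z, lwh, yaw], axis=2)
    scores = 1.0 / (1.0 + np.exp(-hm_logits.reshape(B, H * W, -1)))  # sigmoid
    return boxes, scores


def score_filter(scores, thresholds):
    """Max over classes at every location, then compare with that class's threshold."""
    best = scores.max(axis=-1)
    labels = scores.argmax(axis=-1)
    keep = best > np.asarray(thresholds)[labels]
    return best, labels, keep


# Toy example: 4x5 feature map, 2 classes, one confident peak at cell (y=2, x=3).
B, H, W = 1, 4, 5
reg = np.full((B, H, W, 2), 0.5)
height = np.zeros((B, H, W, 1))
dim = np.zeros((B, H, W, 3))            # exp(0) = 1 m in every dimension
rot = np.zeros((B, H, W, 2))
rot[..., 1] = 1.0                       # sin = 0, cos = 1 -> yaw = 0
hm = np.full((B, H, W, 2), -5.0)
hm[0, 2, 3, 1] = 5.0
boxes, scores = decode_boxes(reg, height, dim, rot, hm, voxel_size=(0.2, 0.2),
                             pc_range=(-10.0, -10.0), out_size_factor=2)
best, labels, keep = score_filter(scores, thresholds=[0.3, 0.3])
print(boxes.shape, int(keep.sum()))     # (1, 20, 7) 1
```

Only the peak cell (flattened index 2*5 + 3 = 13) survives the threshold; in the real pipeline the surviving boxes would then go through rotated NMS.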

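Training: we first need the GT heatmap and boxes, so we zero-initialize hm [B,8,H,W], anno_box [B,500,8], ind [B,500], mask [B,500] and cat [B,500]. The number of GT objects differs from sample to sample, so every sample is padded to at most 500 objects, and mask marks which slots hold a real GT box. We then loop over the GT objects and draw a Gaussian for each one on the heatmap channel of its class.

The Gaussian radius is derived from the box size (l and w measured in feature-map cells) and a minimum IoU overlap; for details, see the Zhihu article "说点Cornernet/Centernet代码里面GT heatmap里面如何应用高斯散射核" (on how the CornerNet/CenterNet code applies a Gaussian kernel to the GT heatmap) on zhihu.com. Three cases are considered: the shifted box lies inside the GT box, the shifted box contains the GT box, or the two boxes cross; the smallest of the three radii is used. The author also enforces a minimum Gaussian radius (min_radius). The Gaussian itself weights each cell by exp(-d^2 / (2σ^2)), where d is the distance to the center and σ = diameter / 6, so cells closer to the center get values closer to 1.

Next we find the feature-map cell that contains the center: truncating the center coordinates to integers and taking the difference gives the sub-cell offset. The offset, z, log(l, w, h) and the sine and cosine of the heading angle together form anno_box. ind is the flattened index of the center cell in H*W (y*W + x), and cat is the object's class. Together these form the training targets (example).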
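
The heatmap part of this step can be sketched in NumPy as below. gaussian_radius reproduces the CornerNet-style formulas used in the code section; draw_gaussian follows draw_umich_gaussian (σ = diameter / 6, window clipped at the borders, element-wise max with existing values) but omits its epsilon clean-up. The voxel size, out_size_factor, point-cloud range, box size and center are hypothetical values chosen for the demo.

```python
import numpy as np


def gaussian_radius(det_size, min_overlap=0.5):
    """Smallest radius over the three overlap cases (same formulas as gaussian_radius in the code section)."""
    height, width = det_size
    a1, b1 = 1, height + width
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    r1 = (b1 + np.sqrt(b1 ** 2 - 4 * a1 * c1)) / 2
    a2, b2 = 4, 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    r2 = (b2 + np.sqrt(b2 ** 2 - 4 * a2 * c2)) / 2
    a3, b3 = 4 * min_overlap, -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    r3 = (b3 + np.sqrt(b3 ** 2 - 4 * a3 * c3)) / 2
    return min(r1, r2, r3)


def draw_gaussian(heatmap, center, radius):
    """Splat exp(-d^2 / (2 sigma^2)) around the integer center, keeping the max with existing values."""
    sigma = (2 * radius + 1) / 6
    ys, xs = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    gaussian = np.exp(-(xs * xs + ys * ys) / (2 * sigma * sigma))  # peak value 1.0 at the center
    x, y = int(center[0]), int(center[1])
    h, w = heatmap.shape
    left, right = min(x, radius), min(w - x, radius + 1)   # clip the window at the map borders
    top, bottom = min(y, radius), min(h - y, radius + 1)
    region = heatmap[y - top:y + bottom, x - left:x + right]  # a view into heatmap
    patch = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    np.maximum(region, patch, out=region)
    return heatmap


# Hypothetical setup: 0.2 m voxels, out_size_factor 2, point-cloud range starting at (0, 0).
voxel_size, out_size_factor, pc_range = (0.2, 0.2), 2, (0.0, 0.0)
l = 4.0 / voxel_size[0] / out_size_factor      # about 10 cells
w = 1.8 / voxel_size[1] / out_size_factor      # about 4.5 cells
radius = max(2, int(gaussian_radius((l, w), min_overlap=0.1)))   # min_radius = 2
center = ((4.12 - pc_range[0]) / voxel_size[0] / out_size_factor,
          (2.28 - pc_range[1]) / voxel_size[1] / out_size_factor)  # about (10.3, 5.7)
hm = np.zeros((16, 24), dtype=np.float32)     # one class channel, H=16, W=24
draw_gaussian(hm, center, radius)
print(radius, hm[5, 10])                      # 2 1.0
```

Here the crossing case gives the smallest radius (about 2.83), which truncates to 2, so a 5x5 Gaussian patch is written with its peak of 1.0 at cell (y=5, x=10).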

At this point we have the GT tensors: hm [B,8,H,W], anno_box [B,500,8], ind [B,500], mask [B,500] and cat [B,500].
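
Filling anno_box, ind, mask and cat for one frame can be sketched as follows. This follows the AssignLabel logic in the code section but is a simplified, hypothetical helper (encode_targets): it skips the heatmap drawing and the check for non-positive l and w, and assumes GT rows of (x, y, z, l, w, h, yaw, class_id) with 1-based class ids. The toy box and grid parameters are made up.

```python
import numpy as np


def encode_targets(gt_boxes, feature_map_size, voxel_size, pc_range, out_size_factor, max_objs=500):
    """Fill anno_box / ind / mask / cat for one frame (regression targets only).

    gt_boxes: [N, 8] rows of (x, y, z, l, w, h, yaw, class_id), class_id starting at 1.
    feature_map_size: (W, H) of the BEV feature map.
    """
    W, H = feature_map_size
    anno_box = np.zeros((max_objs, 8), dtype=np.float32)
    ind = np.zeros(max_objs, dtype=np.int64)
    mask = np.zeros(max_objs, dtype=np.uint8)
    cat = np.zeros(max_objs, dtype=np.int64)
    for k, box in enumerate(gt_boxes[:max_objs]):
        x, y, z, l, w, h, yaw, cls = box
        ct = np.array([(x - pc_range[0]) / voxel_size[0] / out_size_factor,
                       (y - pc_range[1]) / voxel_size[1] / out_size_factor], dtype=np.float32)
        cx, cy = ct.astype(np.int32)          # integer cell containing the center
        if not (0 <= cx < W and 0 <= cy < H):
            continue                          # center outside the feature map
        ind[k] = cy * W + cx                  # flattened index into H*W
        mask[k] = 1
        cat[k] = int(cls) - 1                 # 0-based class index
        # [x offset, y offset, z, log l, log w, log h, sin yaw, cos yaw]
        anno_box[k] = np.concatenate((ct - (cx, cy), [z], np.log([l, w, h]),
                                      [np.sin(yaw), np.cos(yaw)]))
    return anno_box, ind, mask, cat


gt = np.array([[4.12, 2.28, -1.0, 4.0, 1.8, 1.5, 0.3, 2.0]])  # one box of class 2
anno_box, ind, mask, cat = encode_targets(gt, feature_map_size=(24, 16), voxel_size=(0.2, 0.2),
                                          pc_range=(0.0, 0.0), out_size_factor=2)
print(ind[0], mask[:2], cat[0])  # 130 [1 0] 1
```

The center lands in cell (x=10, y=5), so ind = 5 * 24 + 10 = 130 and the stored offset is roughly (0.3, 0.7); unused slots keep mask = 0.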

The model predictions are reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W] and hm [B,8,H,W].

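Apply a sigmoid to the predicted hm, and concatenate the regression outputs into pred_box [B,8,H,W], which is transposed and flattened to [B,H*W,8]. pred_box is then gathered at the indices in ind to get [B,500,8], and the L1 loss against anno_box is computed on the slots where mask is 1. The heatmap is supervised directly with a CornerNet-style focal loss: FastFocalLoss in the original CenterPoint code, while the code below uses FocalLossCenterNet, which computes the same loss densely over the whole heatmap.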
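
A NumPy sketch of the gather and the two losses, mirroring _transpose_and_gather_feat, RegLoss and neg_loss_cornernet from the code section. np.take_along_axis stands in for torch's gather; the helper names and toy tensors are hypothetical. The code_weights and weight scaling applied afterwards in CenterHead.loss is not included.

```python
import numpy as np


def transpose_and_gather(pred, ind):
    """[B, C, H, W] dense prediction -> [B, M, C] values at the M target cells."""
    B, C, H, W = pred.shape
    flat = pred.transpose(0, 2, 3, 1).reshape(B, H * W, C)     # [B, H*W, C]
    return np.take_along_axis(flat, ind[:, :, None], axis=1)   # like torch's gather(1, ...)


def reg_loss(pred, mask, ind, target):
    """Masked L1 loss, one value per regression channel (same normalisation as RegLoss)."""
    gathered = transpose_and_gather(pred, ind)                 # [B, M, C]
    m = mask.astype(np.float32)[:, :, None]                    # [B, M, 1]
    loss = np.abs(gathered * m - target * m) / (m.sum() + 1e-4)
    return loss.sum(axis=(0, 1))                               # [C]


def focal_loss(pred, gt):
    """CornerNet-style focal loss on a sigmoid heatmap (exponents 2 and 4, as in neg_loss_cornernet)."""
    pred = np.clip(pred, 1e-4, 1 - 1e-4)
    pos = gt == 1
    pos_loss = (np.log(pred) * (1 - pred) ** 2)[pos].sum()
    neg_loss = (np.log(1 - pred) * pred ** 2 * (1 - gt) ** 4)[~pos].sum()
    num_pos = pos.sum()
    return -neg_loss if num_pos == 0 else -(pos_loss + neg_loss) / num_pos


# Toy regression check: one valid object at cell (y=2, x=3) of a 4x5 map -> flat index 13.
pred = np.zeros((1, 8, 4, 5))
pred[0, :, 2, 3] = np.arange(8)
ind = np.array([[13, 0, 0]])
mask = np.array([[1, 0, 0]])
target = np.zeros((1, 3, 8))
target[0, 0] = np.arange(8) + 0.5
print(reg_loss(pred, mask, ind, target))   # each channel is about 0.5

# Toy heatmap check: one center on class 1, low scores elsewhere.
gt = np.zeros((1, 2, 4, 5))
gt[0, 1, 2, 3] = 1.0
hm_pred = np.full_like(gt, 0.01)
hm_pred[0, 1, 2, 3] = 0.99
print(focal_loss(hm_pred, gt))
```

The (1 - gt)^4 factor is what makes cells near a center cheap to mispredict: the closer gt is to 1, the smaller the penalty for predicting a high score there.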

3. Code

import logging
import sys
from collections import OrderedDict, defaultdict
from torch import nn
import copy 


import torch
import numpy as np
import torch.nn.functional as F

from ...ops.iou3d_nms import iou3d_nms_cuda
from ..model_utils import model_nms_utils


class Sequential(torch.nn.Module):
    r"""A sequential container.
    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.

    To make it easier to understand, given is a small example::

        # Example of using Sequential
        model = Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Example of using Sequential with OrderedDict
        model = Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))

        # Example of using Sequential with kwargs(python 3.6+)
        model = Sequential(
                  conv1=nn.Conv2d(1,20,5),
                  relu1=nn.ReLU(),
                  conv2=nn.Conv2d(20,64,5),
                  relu2=nn.ReLU()
                )
    """

    def __init__(self, *args, **kwargs):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        for name, module in kwargs.items():
            if sys.version_info < (3, 6):
                raise ValueError("kwargs only supported in py36+")
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        if idx < 0:
            idx += len(self)
        it = iter(self._modules.values())
        for i in range(idx):
            next(it)
        return next(it)

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        if name is None:
            name = str(len(self._modules))
            if name in self._modules:
                raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, input):
        # i = 0
        for module in self._modules.values():
            # print(i)
            input = module(input)
            # i += 1
        return input




def rotate_nms_pcdet(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
    """
    :param boxes: (N, 7) [x, y, z, l, w, h, theta]
    :param scores: (N)
    :param thresh:
    :return:
    """
    # transform back to pcdet's coordinate
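    # (swap l/w and convert the heading angle to OpenPCDet's convention)
    boxes = boxes[:, [0, 1, 2, 4, 3, 5, -1]]
    boxes[:, -1] = -boxes[:, -1] - np.pi /2

    order = scores.sort(0, descending=True)[1]  # sort the N boxes by score, highest first
    if pre_maxsize is not None:  # keep only the top pre_maxsize boxes before NMS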
        order = order[:pre_maxsize]

    boxes = boxes[order].contiguous()

    keep = torch.LongTensor(boxes.size(0))

    if len(boxes) == 0:
        num_out =0
    else:
        num_out = iou3d_nms_cuda.nms_gpu(boxes, keep, thresh)

    selected = order[keep[:num_out].cuda()].contiguous()

    if post_max_size is not None:
        selected = selected[:post_max_size]

    return selected 


def kaiming_init(
    module, a=0, mode="fan_out", nonlinearity="relu", bias=0, distribution="normal"
):
    assert distribution in ["uniform", "normal"]
    if distribution == "uniform":
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity
        )
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity
        )
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def gaussian_radius(det_size, min_overlap=0.5):
    """
    compute gaussian radius by min_overlap, you can get principle in <<CenterNet :Objects as Points>> paper
    """
    height, width = det_size  # box height and width

    a1  = 1
    b1  = (height + width)
    c1  = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
    r1  = (b1 + sq1) / 2

    a2  = 4
    b2  = 2 * (height + width)
    c2  = (1 - min_overlap) * width * height
    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
    r2  = (b2 + sq2) / 2

    a3  = 4 * min_overlap
    b3  = -2 * min_overlap * (height + width)
    c3  = (min_overlap - 1) * width * height
    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
    r3  = (b3 + sq3) / 2
    return min(r1, r2, r3)

def gaussian2D(shape, sigma=1):
    """
    compute gaussian
    """
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m+1,-n:n+1]  # e.g. for a 7x7 kernel: y is [7,1], x is [1,7]

    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) # [7,7], larger closer to the center
    h[h < np.finfo(h.dtype).eps * h.max()] = 0  # zero out values below machine epsilon relative to the peak
    return h


def draw_umich_gaussian(heatmap, center, radius, k=1):
    """
    draw gaussian in heatmap
    """
    diameter = 2 * radius + 1
    # compute gaussian value
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)  # diameter x diameter kernel (7x7 when radius=3)

    x, y = int(center[0]), int(center[1])  # integer center coordinates

    height, width = heatmap.shape[0:2]

    # get gaussian map pos
    left, right = min(x, radius), min(width - x, radius + 1)  # if the center is closer than r to an edge, clip to avoid going out of bounds
    top, bottom = min(y, radius), min(height - y, radius + 1)

    # get masked heatmap pos 
    masked_heatmap  = heatmap[y - top:y + bottom, x - left:x + right]  # heatmap window to update
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]  # matching part of the Gaussian

    # guard kept from debugging: skip if either slice is empty
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)  # keep the larger of the two values
    return heatmap

def _gather_feat(feat, ind, mask=None):
    dim  = feat.size(2) # 8
    ind  = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)  # [B,500] -> [B,500,1] -> [B,500,8]: index of each object's center on the flattened feature map
    feat = feat.gather(1, ind)  # pick the entries along dim 1 (H*W) given by ind
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat

def _transpose_and_gather_feat(feat, ind):
    feat = feat.permute(0, 2, 3, 1).contiguous()  # [B,200,380,8]
    feat = feat.view(feat.size(0), -1, feat.size(3)) # [B,H*W,8]
    feat = _gather_feat(feat, ind)
    return feat

def _circle_nms(boxes, min_radius, post_max_size=83):
    """
    NMS according to center distance; not used now (circle_nms is not defined or imported in this file)
    """
    keep = np.array(circle_nms(boxes.cpu().numpy(), thresh=min_radius))[:post_max_size]

    keep = torch.from_numpy(keep).long().to(boxes.device)

    return keep 


class RegLoss(nn.Module):
  '''Regression loss for an output tensor
    Arguments:
      output (batch x dim x h x w)
      mask (batch x max_objects)
      ind (batch x max_objects)
      target (batch x max_objects x dim)
  '''
  def __init__(self):
    super(RegLoss, self).__init__()
  
  def forward(self, output, mask, ind, target):
    # output[B,8,200,380]  pred[B,500,8]
    # compute mask by ind as not all box number is same and not all grid in use
    pred = _transpose_and_gather_feat(output, ind)
    mask = mask.float().unsqueeze(2) 

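    # use L1 loss: pred and target are both [B,500,8]; multiply by mask, normalize by the number of
    # objects, then sum over the B and 500 dims to get one loss per regression channel (8 values)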
    loss = F.l1_loss(pred*mask, target*mask, reduction='none')
    loss = loss / (mask.sum() + 1e-4)
    loss = loss.transpose(2 ,0).sum(dim=2).sum(dim=1)
    return loss

class FastFocalLoss(nn.Module):
  '''
  Reimplemented focal loss, exactly the same as the CornerNet version.
  Faster and costs much less memory.
  '''
  def __init__(self):
    super(FastFocalLoss, self).__init__()

  def forward(self, out, target, ind, mask, cat):
    '''
    Arguments:
      out, target: B x C x H x W
      ind, mask: B x M
      cat (category id for peaks): B x M
    '''
    mask = mask.float()
    gt = torch.pow(1 - target, 4)
    # compute negative loss in heatmap
    neg_loss = torch.log(1 - out) * torch.pow(out, 2) * gt
    neg_loss = neg_loss.sum()

    pos_pred_pix = _transpose_and_gather_feat(out, ind) # B x M x C
    pos_pred = pos_pred_pix.gather(2, cat.unsqueeze(2)) # B x M x 1
    num_pos = mask.sum()

    # compute positive loss in heatmap
    pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) * \
               mask.unsqueeze(2)
    pos_loss = pos_loss.sum()
    if num_pos == 0:
      return - neg_loss
    return - (pos_loss + neg_loss) / num_pos



def neg_loss_cornernet(pred, gt, mask=None):
    """
    Refer to https://github.com/tianweiy/CenterPoint.
    Modified focal loss. Exactly the same as CornerNet. Runs faster and costs a little bit more memory
    Args:
        pred: (B x 8 x h x w)
        gt: (B x 8 x h x w)
        mask: (B x h x w)
    Returns:
    """
    pos_inds = gt.eq(1).float()  # 1 only at object centers
    neg_inds = gt.lt(1).float()  # 1 everywhere except object centers

    neg_weights = torch.pow(1 - gt, 4)  # [B,8,H,W]; down-weights negatives near a center (where gt is close to 1)

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds  # negatives near a center contribute little

    if mask is not None:
        mask = mask[:, None, :, :].float()
        pos_loss = pos_loss * mask
        neg_loss = neg_loss * mask
        num_pos = (pos_inds.float() * mask).sum()
    else:
        num_pos = pos_inds.float().sum()

    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos  # normalize the summed loss by the number of positives
    return loss


class FocalLossCenterNet(nn.Module):
    """
    Refer to https://github.com/tianweiy/CenterPoint
    """
    def __init__(self):
        super(FocalLossCenterNet, self).__init__()
        self.neg_loss = neg_loss_cornernet

    def forward(self, out, target, mask=None):
        return self.neg_loss(out, target, mask=mask)



class AssignLabel(object):
    def __init__(self, **kwargs):
        """Return CenterNet training labels like heatmap, height, offset"""

        self.tasks = kwargs["tasks"] #assigner_cfg.target_assigner.tasks

        assigner_cfg = kwargs["cfg"]

        self.out_size_factor = assigner_cfg.out_size_factor # 2
        self.gaussian_overlap = assigner_cfg.gaussian_overlap # 0.1
        self._max_objs = assigner_cfg.max_objs  # 500
        self._min_radius = assigner_cfg.min_radius # 2
        # tasks
        self.class_names = self.tasks["class_names"] # list of 8 class names
        self.num_classes = self.tasks["num_class"]  # 8

    def __call__(self, res,  grid_size , voxel_size , pc_range):
        max_objs = self._max_objs   # 500

        feature_map_size = grid_size[:2] // self.out_size_factor  # feature map size (W, H)
        
        draw_gaussian = draw_umich_gaussian
        # each GT box is (x, y, z, l, w, h, yaw, class)
        gt_boxes = res['gt_boxes'].cpu().numpy() # GT boxes from data_dict, [B,N,8]
        batch_size = res['batch_size']

        # hm is heatmap
        hms, anno_boxs, inds, masks, cats = [], [], [], [], []

        #jinmu: batch one by one compute now
        for batch_idx in range(batch_size):
            batch_box = gt_boxes[batch_idx,...]  #[n,8]
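            # N is the largest object count of any sample in the batch, so shorter samples are
            # zero-padded; a nonzero class id (last column) marks a real object
            batch_box_mask = batch_box[...,-1] != 0
            if np.all(batch_box_mask == False):
                batch_box_valid_num = 0
            else:  # e.g. batch_box_mask = [1,1,1,1,0,0,0,0,0]; for a 1-D mask np.where returns only column indices
                batch_box_valid_num = np.where(batch_box_mask)[0].squeeze().max() + 1  # number of valid objects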

            # c, h, w  [8, 200,380]
            hm = np.zeros((len(self.class_names), feature_map_size[1], feature_map_size[0]),
                            dtype=np.float32)
            # [500, 8]
            anno_box = np.zeros((max_objs, 8), dtype=np.float32)
            # [500]
            ind = np.zeros((max_objs), dtype=np.int64)
            mask = np.zeros((max_objs), dtype=np.uint8) # [500]
            cat = np.zeros((max_objs), dtype=np.int64)  # [500]

            # the box count must be the same for every frame so the batch can be
            # computed in one pass, but frames actually have different numbers of
            # boxes, so a mask is kept to mark the valid slots
            num_objs = min(batch_box_valid_num, max_objs)  # number of objects in the current frame

            for k in range(num_objs):
                cls_id = batch_box[k][-1] - 1  # 0-based class id
                l, w, h = batch_box[k][3], batch_box[k][4], batch_box[k][5]
                # l and w measured in feature-map cells
                w, l = w / voxel_size[1] / self.out_size_factor, l / voxel_size[0] / self.out_size_factor
                if w > 0 and l > 0:  # radius from the box size: solve for r under a minimum-overlap constraint (inside / containing / crossing cases)
                    radius = gaussian_radius((l, w), min_overlap=self.gaussian_overlap)  # l, w are floats; gaussian_overlap is 0.1 here
                    radius = max(self._min_radius, int(radius))  # enforce the minimum radius (2 here)

                    # center coordinates on the feature map
                    x, y, z = batch_box[k][0], batch_box[k][1], batch_box[k][2]
                    coor_x, coor_y = (x - pc_range[0]) / voxel_size[0] / self.out_size_factor, \
                                        (y - pc_range[1]) / voxel_size[1] / self.out_size_factor
                    
                    ct = np.array([coor_x, coor_y], dtype=np.float32)  
                    ct_int = ct.astype(np.int32)  # integer cell containing the center

                    # throw out not in range objects to avoid out of array area when creating the heatmap
                    # if beyond range, then continue
                    if not (0 <= ct_int[0] < feature_map_size[0] and 0 <= ct_int[1] < feature_map_size[1]):
                        continue 

                    # draw gaussian in heatmap gt
                    draw_gaussian(hm[int(cls_id)], ct, radius)  # draw on the heatmap channel of this class

                    new_idx = k  # slot of the k-th object
                    x, y = ct_int[0], ct_int[1]

                    cat[new_idx] = cls_id  # class of this object
                    ind[new_idx] = y * feature_map_size[0] + x  # flattened index of the center cell on the feature map
                    mask[new_idx] = 1  # mark this slot as a valid object
                    rot = batch_box[k][6]
                    # fill regression target, ct - (x,y) is x_offset and y_offset
                    # rot is yaw angle
                    anno_box[new_idx] = np.concatenate(
                        (ct - (x, y), z, np.log(batch_box[k][3:6]),
                        np.sin(rot), np.cos(rot)), axis=None)  # x/y offset within the cell, z, log(l, w, h), sin(yaw), cos(yaw)

            hms.append(hm)
            anno_boxs.append(anno_box)
            masks.append(mask)
            inds.append(ind)
            cats.append(cat)

        hms = torch.from_numpy(np.stack(hms)).cuda()  # stack along a new batch dimension (dim 0)
        anno_boxs = torch.from_numpy(np.stack(anno_boxs)).cuda()
        inds = torch.from_numpy(np.stack(inds)).cuda()
        cats = torch.from_numpy(np.stack(cats)).cuda()
        masks = torch.from_numpy(np.stack(masks)).cuda()
        # hm [B,8,H,W], anno_box [B,500,8], ind [B,500], mask [B,500], cat [B,500]
        example = {'hm': hms, 'anno_box': anno_boxs, 'ind': inds, 'mask': masks, 'cat': cats}

        return example


class SepHead(nn.Module):
    """
    SepHead contains the actual heads: (heatmap) (x offset, y offset) (z) (dim) (sin(theta), cos(theta))
    """
    def __init__(
        self,
        in_channels,
        heads,
        head_conv=64,
        final_kernel=1,
        bn=False,
        init_bias=-2.19,
        **kwargs,
    ):
        super(SepHead, self).__init__(**kwargs)

        self.heads = heads # {'reg': [2, 2], 'height': [1, 2], 'dim': [3, 2], 'rot': [2, 2], 'hm': [8, 2]}
        for head in self.heads:  # iterate over head names (dict keys)
            classes, num_conv = self.heads[head]  # (output channels of this head, number of convs in it)

            fc = Sequential()
            # layers number decided by config
            for i in range(num_conv-1):
                fc.add(nn.Conv2d(in_channels, head_conv,
                    kernel_size=final_kernel, stride=1, 
                    padding=final_kernel // 2, bias=True))
                if bn:
                    fc.add(nn.BatchNorm2d(head_conv))
                fc.add(nn.ReLU())

            # output conv
            fc.add(nn.Conv2d(head_conv, classes,
                    kernel_size=final_kernel, stride=1, 
                    padding=final_kernel // 2, bias=True))    
            # hm: output bias set to init_bias (-2.19); other heads: Kaiming init
            if 'hm' in head:
                fc[-1].bias.data.fill_(init_bias)
            else:
                for m in fc.modules():
                    if isinstance(m, nn.Conv2d):
                        kaiming_init(m)
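            # each head: (num_conv - 1) conv+BN+ReLU blocks, then one output conv giving the head's channel count
            # register fc as a submodule named after the head, so it can be fetched later with getattr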
            self.__setattr__(head, fc)
        

    def forward(self, x):
        ret_dict = dict()        
        for head in self.heads:
            ret_dict[head] = self.__getattr__(head)(x)
        # ret_dict: reg [B,2,200,380], height [B,1,200,380], dim [B,3,200,380], rot [B,2,200,380], hm [B,8,200,380]
        return ret_dict


class CenterHead(nn.Module):
    def __init__(
        self,
        model_cfg,
        input_channels=[128,],
        num_class=1,
        class_names=None,
        grid_size=[0.32,0.32,0.16],
        point_cloud_range=None,
        predict_boxes_when_training=False,
        logger=None,
        init_bias=-2.19,
        num_hm_conv=2,
    ):
        super(CenterHead, self).__init__()
        assert(len(class_names) == num_class)
        
        tasks = dict(num_class=num_class, class_names=class_names)
        self.label_assigner = AssignLabel(cfg=model_cfg.TARGET_ASSIGNER_CONFIG, tasks=tasks)
        
        self.out_size_factor = model_cfg.TARGET_ASSIGNER_CONFIG.out_size_factor # 2
        self.model_cfg = model_cfg

        self.class_names = [class_names]  # wrap the class-name list in another list: [[a, b, c, ...]] (one task)
        self.num_classes = [num_class]  # [8]

        self.code_weights = model_cfg.code_weights #[5.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0]
        self.weight = model_cfg.weight # 0.25 
        
        self.in_channels = input_channels # 384

        #self.crit = FastFocalLoss()
        self.crit = FocalLossCenterNet()
        self.crit_reg = RegLoss()

        

        common_heads = model_cfg.common_heads #{'reg': [ 2, 2 ],'height': [ 1, 2 ],'dim': [ 3, 2 ],'rot': [ 2, 2 ]}

        self.box_n_dim = 9 if 'vel' in common_heads else 7  # 7
        self.use_direction_classifier = False 

        if not logger:
            logger = logging.getLogger("CenterHead")
        self.logger = logger

        logger.info(
            f"num_classes: {self.num_classes}"
        )

        # a shared convolution 
        share_conv_channel = 64 if "share_conv_channel" not in model_cfg else model_cfg.share_conv_channel # 64
        self.shared_conv = nn.Sequential(
            nn.Conv2d(self.in_channels, share_conv_channel,
            kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(share_conv_channel),
            nn.ReLU(inplace=True)
        )

        self.tasks = nn.ModuleList()
        print("Use HM Bias: ", init_bias)

        for num_cls in self.num_classes:  # self.num_classes is [8], so this loop runs once
            heads = copy.deepcopy(common_heads) 
            heads.update(dict(hm=(num_cls, num_hm_conv))) #{'reg': [2, 2], 'height': [1, 2], 'dim': [3, 2], 'rot': [2, 2], 'hm': [8, 2]}
            self.tasks.append(
                SepHead(share_conv_channel, heads, bn=True, init_bias=init_bias, final_kernel=3)
            )

        self.frozen_param = model_cfg.FROZON_PARAM
        self.frozen_parameters()

        logger.info("Finish CenterHead Initialization")

    def forward(self, data_dict, *kwargs):

        x = data_dict['spatial_features_2d'] # [B, 384, 200, 380]
        x = self.shared_conv(x)  # reduce channels to 64 first
        ret_dicts = []

        for task in self.tasks:
            ret_dicts.append(task(x))
        # each element is a dict: reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W], hm [B,8,H,W]
        data_dict['centerhead_preds'] = ret_dicts

        return data_dict

    def _sigmoid(self, x):
        y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
        return y

    def loss(self, data_dict, **kwargs):
        # dict of targets built from the GT: hm [B,8,H,W], anno_box [B,n,8], ind [B,n], mask [B,n], cat [B,n]
        example = self.label_assigner(data_dict, kwargs["grid_size"], kwargs["voxel_size"], kwargs["pc_range"])

        # get centerhead output reg[B,2,200,380] heigh[B,1,200,380] dim [B,3,200,380] rot [B,2,200,380] hm [B,8,200,380]
        preds_dicts = data_dict['centerhead_preds']

        assert(len(preds_dicts) == 1)
        # TODO refactor this
        preds_dict = preds_dicts[0]  # preds_dicts is a list (one entry per task); take the dict
        
        # apply sigmoid for heatmap output
        preds_dict['hm'] = self._sigmoid(preds_dict['hm'])  # sigmoid clamped to [1e-4, 1-1e-4] so the log terms in the focal loss stay finite
        # hm_loss = self.crit(
        #     preds_dict['hm'], 
        #     example['hm'], 
        #     example['ind'], 
        #     example['mask'], 
        #     example['cat']
        #     )
        
        hm_loss = self.crit(preds_dict['hm'], example['hm'])  # uses FocalLossCenterNet

        target_box = example['anno_box']
        # not care about vel as not vel now
        if 'vel' in preds_dict:
            preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
                                                preds_dict['vel'], preds_dict['rot']), dim=1)  
        else:
            preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
                                                preds_dict['rot']), dim=1)   

        # Regression loss for dimension, offset, height, rotation: a length-8 loss tensor
        box_loss = self.crit_reg(preds_dict['anno_box'], example['mask'], example['ind'], target_box)
        box_loss = box_loss * box_loss.new_tensor(self.code_weights)  # new_tensor gives code_weights the same dtype and device as box_loss
        
        # anno_box layout (no vel): [0:2] offset, [2] z, [3:6] log(l, w, h), [6:8] sin/cos
        reg_loss = box_loss[:2]
        height_loss = box_loss[2]
        dim_loss = box_loss[3:6]
        rot_loss = box_loss[6:8]
        
        loc_loss = box_loss.sum()
        loc_loss *= self.weight

        # total loss
        loss = hm_loss + loc_loss
        #ret = {'loss': loss, 'hm_loss': hm_loss, 'loc_loss':loc_loss, 'loc_loss_elem': box_loss.detach().cpu(), 'num_positive': example['mask'][0].float().sum()}
        # ret = {'hm_loss': hm_loss, 'loc_loss': loc_loss, 
        #         'reg_loss': reg_loss, 'height_loss': height_loss, 
        #         'dim_loss': dim_loss, 'rot_loss': rot_loss}

        ret = {'hm_loss': hm_loss, 'loc_loss': loc_loss}
        
        return ret
    
    def frozen_parameters(self):
        if self.frozen_param:
            for parameter in self.parameters():
                parameter.requires_grad = False

    @torch.no_grad()
    def predict(self, preds_dicts, test_cfg, **kwargs):
        """decode, nms, then return the detection result.
        """

        voxel_size = kwargs["voxel_size"]
        pc_range = kwargs["pc_range"]

        post_center_range = pc_range
        # list of dicts: reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W], hm [B,8,H,W]
        preds_dicts = preds_dicts['centerhead_preds']

        if len(post_center_range) > 0:
            post_center_range = torch.tensor(
                post_center_range,
                dtype=preds_dicts[0]['hm'].dtype,
                device=preds_dicts[0]['hm'].device,
            )

        rets = []
        #jinmu now only support one task
        for task_id, preds_dict in enumerate(preds_dicts):
            # convert B C H W to B H W C 
            for key, val in preds_dict.items():
                preds_dict[key] = val.permute(0, 2, 3, 1).contiguous()

            batch_size = preds_dict['hm'].shape[0]
            batch_hm = torch.sigmoid(preds_dict['hm'])

            # exp for dim output to keep dim > 0
            batch_dim = torch.exp(preds_dict['dim'])  # dim is (l, w, h), regressed in log space

            # sin(theta) and cos(theta)
            batch_rots = preds_dict['rot'][..., 0:1]
            batch_rotc = preds_dict['rot'][..., 1:2]

            # x offset and y offset output
            batch_reg = preds_dict['reg']
            # z output
            batch_hei = preds_dict['height']

            # atan to recover true theta
            batch_rot = torch.atan2(batch_rots, batch_rotc)  # recover the angle from sine and cosine

            batch, H, W, num_cls = batch_hm.size()

            # reshape for compute convenient
            batch_reg = batch_reg.reshape(batch, H*W, 2)
            batch_hei = batch_hei.reshape(batch, H*W, 1)

            batch_rot = batch_rot.reshape(batch, H*W, 1)
            batch_dim = batch_dim.reshape(batch, H*W, 3)
            batch_hm = batch_hm.reshape(batch, H*W, num_cls)  # merge H and W for convenience

            #compute x and y axies for each grid for later to recover lidar axies x y with 
            # x_offset and y_offset
            ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
            ys = ys.view(1, H, W).repeat(batch, 1, 1).to(batch_hm.device).float()
            xs = xs.view(1, H, W).repeat(batch, 1, 1).to(batch_hm.device).float()

            # x y  + x_offset y_offset to recover continuous x y value
            xs = xs.view(batch, -1, 1) + batch_reg[:, :, 0:1]
            ys = ys.view(batch, -1, 1) + batch_reg[:, :, 1:2]

            xs = xs * self.out_size_factor * voxel_size[0] + pc_range[0]
            ys = ys * self.out_size_factor * voxel_size[1] + pc_range[1]

            # jinmu: not care aboud this as we has not vel output now
            if 'vel' in preds_dict:
                batch_vel = preds_dict['vel']
                batch_vel = batch_vel.reshape(batch, H*W, 2)
                batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_vel, batch_rot], dim=2)
            else: 
                batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_rot], dim=2)

            if test_cfg.get('per_class_nms', False):
                pass 
            else:
                rets.append(self.post_processing(batch_box_preds, batch_hm, test_cfg, post_center_range)) 

        assert(len(rets) == 1) # only one task

        return rets[0]

    @torch.no_grad()
    def post_processing(self, batch_box_preds, batch_hm, test_cfg, post_center_range):
        batch_size = len(batch_hm)
        # batch_box_preds [B,H*W,7] batch_hm [B,H*W,8]
        prediction_dicts = []
        for i in range(batch_size):  # process one sample at a time
            box_preds = batch_box_preds[i]
            hm_preds = batch_hm[i]

            # score and label are obtained by a max over the 8 class channels of the heatmap
            scores, labels = torch.max(hm_preds, dim=-1)  # max score and its index (the class), both [H*W]

            # score mask is get as > score_thresh
            #score_mask = scores > test_cfg.score_threshold 
            score_threshold = torch.tensor(test_cfg.score_threshold, device=scores.device)[labels]  # per-class threshold for each of the H*W locations
            score_mask = scores > score_threshold  # keep locations whose score exceeds their class threshold

            # distance_mask means that only keep 3d box center in some range
            # not use this in perception postprocess code
            distance_mask = (box_preds[..., :3] >= post_center_range[:3]).all(1) \
                & (box_preds[..., :3] <= post_center_range[3:]).all(1)

            # mask is intersection of two mask
            mask = distance_mask & score_mask 

            # get masked data
            box_preds = box_preds[mask]  # keep the boxes (out of H*W) that pass both masks
            scores = scores[mask]
            labels = labels[mask]

            # get box for nms, each box in [x y z dx dy dz theta] format
            boxes_for_nms = box_preds[:, [0, 1, 2, 3, 4, 5, -1]]

            # bev rotated box nms
            selected = rotate_nms_pcdet(boxes_for_nms, scores, 
                                thresh=test_cfg.nms.nms_iou_threshold,
                                pre_maxsize=test_cfg.nms.nms_pre_max_size,
                                post_max_size=test_cfg.nms.nms_post_max_size)

            # selected is box mask after nms
            selected_boxes = box_preds[selected]
            selected_scores = scores[selected]
            selected_labels = labels[selected]

            # fill result, selected_boxes: n * 7, selected_scores: n * 1,
            # selected_labels: n * 1
            record_dict = {
                'pred_boxes': selected_boxes,
                'pred_scores': selected_scores,
                'pred_labels': selected_labels + 1
            }

            prediction_dicts.append(record_dict)

        return prediction_dicts 
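
In SepHead, the output conv of the hm head has its bias filled with init_bias = -2.19, while the other heads use Kaiming init. A common reading (the focal-loss prior initialization popularized by RetinaNet; the post does not state this) is that the bias b is chosen so that sigmoid(b) equals a small prior probability π, here about 0.1: b = -log((1 - π) / π). The short check below verifies that arithmetic and the bound that the clamp in _sigmoid puts on each log term; prior_bias and sigmoid are hypothetical helper names.

```python
import math


def prior_bias(pi):
    """Bias b with sigmoid(b) == pi, i.e. b = -log((1 - pi) / pi)."""
    return -math.log((1 - pi) / pi)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


print(round(prior_bias(0.1), 4))   # -2.1972
print(round(sigmoid(-2.19), 4))    # 0.1007
# The clamp to [1e-4, 1 - 1e-4] in CenterHead._sigmoid bounds each log term in the focal loss:
print(round(math.log(1e-4), 2))    # -9.21
```

Starting every heatmap location near 0.1 instead of 0.5 keeps the very large number of background cells from producing a huge loss in the first iterations.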


Reposted from blog.csdn.net/slamer111/article/details/130596258