CenterNet 后处理过程及源码解析

目录

1、写在前面

2、后处理源码解析

2.1 CenterNet推理过程

2.2 后处理源码解析

2.2.1 NMS

2.2.2 Top K

2.2.3 提取reg和wh

2.2.4 是否对每一类分别设置wh

2.2.5 构造bounding boxes并组合置信度得分、类型信息

3、写在后面


1、写在前面

CenterNet是2019年的一篇论文Objects as Points中提出的网络。由于其结构简单、可扩展性强、模型转换方便的优点,至今仍很受欢迎。其GitHub地址为:https://github.com/xingyizhou/CenterNet

整个模型包括三部分:backbone、上采样、heads。其中,backbone部分即为传统的分类模型去掉fc层,用于提取高级语义特征;上采样部分用于从低分辨率的feature map恢复部分分辨率,提高最终输出的feature map的分辨率;heads部分则得到三部分输出:每一类目标中心点的heatmap、包含每个目标尺寸(W、H)的wh、中心点偏移量offset。

对于最终输出的heatmap、wh、offset,其shape分别为:[B, C, H, W]、[B, 2, H, W]、[B, 2, H, W]。其中,B为batch大小,C为类别个数、H和W为feature map的高和宽。

我们今天要讲的后处理,即是对heatmap、wh、offset进行的处理,以得到最终bounding-boxes形式的输出,主要以源码中“src/lib/models/decode.py”的代码为讲解对象,介绍整个后处理的逻辑,并对其中的NMS、topK等过程进行逐行解释。

2、后处理源码解析

2.1 CenterNet推理过程

推理的完整流程为:

  • 输入一张或者一个batch的图片,经过backbone(这里包括下采样和上采样)后输出feature map,尺寸为原图尺寸的1/4;
  • 然后送入三个分支,分别得到heatmap、wh、offset;
  • 最后通过后处理过程,得到bounding boxes。

而后处理仅在在推理阶段使用。

推理时,通过以下代码,可以得到hm、wh、reg,也即上文所说的heatmap、wh、offset:

  def process(self, images, return_time=False):
    with torch.no_grad():
      output = self.model(images)[-1]
      hm = output['hm'].sigmoid_()
      wh = output['wh']
      reg = output['reg'] if self.opt.reg_offset else None

然后,将这三者送入后处理流程:

dets = ctdet_decode(hm, wh, reg=reg, cat_spec_wh=self.opt.cat_spec_wh, K=self.opt.K)

2.2 后处理源码解析

ctdet_decode函数定义在“src/lib/models/decode.py”中:

def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
    batch, cat, height, width = heat.size()

    # 2.2.1 NMS
    # 如果参数heat传进来之前没有进行sigmoid,就需要先使用下面的方法归一化heat中的值
    # heat = torch.sigmoid(heat)
    # perform nms on heatmaps
    heat = _nms(heat)
      
    # 2.2.2 TopK
    scores, inds, clses, ys, xs = _topk(heat, K=K)

    # 2.2.3 提取reg和wh
    if reg is not None:
      reg = _transpose_and_gather_feat(reg, inds)
      reg = reg.view(batch, K, 2)
      xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
      ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
      xs = xs.view(batch, K, 1) + 0.5
      ys = ys.view(batch, K, 1) + 0.5
    wh = _transpose_and_gather_feat(wh, inds)
    
    # 2.2.4 是否对每一类分别设置wh
    if cat_spec_wh:
      wh = wh.view(batch, K, cat, 2)
      clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long()
      wh = wh.gather(2, clses_ind).view(batch, K, 2)
    else:
      wh = wh.view(batch, K, 2)

    # 2.2.5 构造bounding boxes并组合置信度得分、类型信息
    clses  = clses.view(batch, K, 1).float()
    scores = scores.view(batch, K, 1)
    bboxes = torch.cat([xs - wh[..., 0:1] / 2, 
                        ys - wh[..., 1:2] / 2,
                        xs + wh[..., 0:1] / 2, 
                        ys + wh[..., 1:2] / 2], dim=2)
    detections = torch.cat([bboxes, scores, clses], dim=2)
      
    return detections

下面,根据该函数的每个部分,分别进行解析。

2.2.1 NMS

从上述代码可以看出,首先需要对heat进行NMS处理,这里的NMS与传统的Anchor-based的检测算法不同,Anchor-based类算法的NMS是基于IOU进行的过滤,而CenterNet里面的NSM及其简单,仅仅是提取heatmap中的峰值,仅用一个3*3的maxpooling即可。NMS的代码如下,每一句我都添加了注释:

def _nms(heat, kernel=3):
    
    # 设置padding值,使得经过maxpooling后尺寸不变
    pad = (kernel - 1) // 2
    
    # 利用maxpooling将峰值保留,而非峰值部分的信息被抹去了
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=pad)
    
    # 将峰值的位置索引设为True,非峰值为False
    keep = (hmax == heat).float()

    # 返回结果中,峰值部分值保留,非峰值部分值为0
    return heat * keep

2.2.2 Top K

经过NMS之后,紧接着就是topK操作了,这一步的目的是得到置信度排名前K个中心点的置信度得分、索引、类别、中心点坐标,代码如下,同样都添加了注释:

# 这里的scores是传入的heatmap
def _topk(scores, K=40):
    batch, cat, height, width = scores.size()
    
    # 首先将scores的H和W两个维度合并展开,然后利用torch.topk函数得到排序后的值及其索引,结果为:
    # topk_scores   size:(B,num_cls, K)   float
    # topk_inds     size:(B,num_cls, K)   int
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

    # 得到中心点坐标(x,y), 其中inds = y * W + x, 现在知道inds,逆过程即为求x,y
    topk_inds = topk_inds % (height * width) # 取%是为了使索引不至于超出范围
    topk_ys   = (topk_inds / width).int().float()
    topk_xs   = (topk_inds % width).int().float()
      
    # 对每个峰值,确定其类别(当多个类的中心点重合时,只能保留一个置信度最大的)
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    # 输出的维度:
    # topk_score    size:(B, K)   float
    # topk_ind      size:(B, K)   int

    # 得到top K个目标的类别,在多个类别时,topk_ind除以K可以将同类的点统一到一个值
    topk_clses = (topk_ind / K).int()
    # top K 个点的索引
    topk_inds = _gather_feat(
        topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
    # top K个中心点坐标
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

2.2.3 提取reg和wh

接下来,就是提取topK个中心点对应的reg和wh。如果使用了reg也即中心点偏移,则先提取reg然后再将中心点坐标加上偏移量;如果没有使用reg,则直接将中心点坐标加上0.5的偏移。对reg和wh的topK提取主要使用了“_transpose_and_gather_feat”方法,其代码如下,同样逐行进行了注释说明:

def _gather_feat(feat, ind, mask=None):
    # 第三个维度的值,如reg是中心点(x,y)的偏移量,为2
    dim  = feat.size(2)
    # 将ind从[B, K]转换为[B, K, 1], 然后使用expand扩展为[B, K, dim]
    ind  = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    # 从第二个维度,按照ind提供的的索引,提取对应的元素
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat

def _transpose_and_gather_feat(feat, ind):
    # 对输入feat先进行维度置换,将第二个维度换到最后面
    feat = feat.permute(0, 2, 3, 1).contiguous()
    # 将中间两个维度合并到一维,并后的维度其维数为feature map 的宽高乘积,
    feat = feat.view(feat.size(0), -1, feat.size(3))
    # 调用_gather_feat方法,从feat提取ind指定索引位置的元素
    feat = _gather_feat(feat, ind)
    return feat

2.2.4 是否对每一类分别设置wh

接下来的cat_spec_wh代表是否对每个类别分别设置了wh。如果是,则wh原来的维度其实是[B, C*2, H, W],经过_transpose_and_gather_feat处理之后是[B, K, C*2],所以就需要将最后一个维度展开,变为[B, K, C, 2],然后调用gather方法从C这个维度开始按照clses_ind提取topK个中心点对应的宽高信息;如果否,wh经过_transpose_and_gather_feat处理之后就直接是[B, K, 2]了,else后面的那句“wh = wh.view(batch, K, 2)”我个人感觉应该就可以不用了(对这一句不知道理解的对不对,有知道的小伙伴烦请留言告知~)。

2.2.5 构造bounding boxes并组合置信度得分、类型信息

最后一部分,就是先把clses、scores展开到[B, K, 1],然后利用中心点坐标(xs, ys)和宽高信息wh,组合成为bounding boxes,再把boxes、scores、clses拼接起来,作为函数返回值。

3、写在后面

至此,CenterNet的后处理源码就解析完了,主要结合的是官方源码的后处理部分,文章对每一步涉及到的操作都做了注解,相信看一遍就会明白;最好的建议,还是结合代码,使用一个示例运行一遍,看一看每一步运算后,tensor是如何变化的,我也是这样来一步步做的,也因此有了这篇解析。这篇文章的目的就是将后处理过程吃透,便于我们在推理过程的使用,同时作为一个备忘,用于我后续回看。那对于网络的forward部分,我们则可以将其转为ONNX、TensorRT等形式加速计算,然后将结果利用该后处理过程得到我们想要的边框、类别、置信度等信息。

猜你喜欢

转载自blog.csdn.net/oYeZhou/article/details/111224567
今日推荐