Centernet算法

相比yolo,ssd,faster_rcnn等依靠大量anchor的检测网络,CenterNet是一种anchor-free的目标检测网络。Centernet将目标看作一个点,一个目标由一个特征点确定。centernet将输入的图片划分成若干个区域,每个区域存在一个特征点。centernet网络的预测结果会判断这个特征是否由对应的物体,以及物体的种类和置信度;同时还会对特征点进行调整获得物体的中心坐标;并且回归出物体的宽高。

一、网络结构

Centernet网络结构主要分成下面3个部分:

1、主干特征提取网络

Centernet用到的主干特征网络有多种,如一般是以Hourglass Network(主要用于人体姿态估计)、DLANet(Deep Layer Aggregation)或者Resnet等。原始的centernet输入输入网络图片的尺寸为3x512x512的时候,最后一个特征图的shape为(2048,16,16)。

2、上采样,获得高分辨率特征图

对于该部分,centernet利用三次反卷积进行上采样,这3个反卷积的输出通道数分别为256,128,64。每一次反卷积,特征图的高和宽会变为原来的两倍,因此,在进行三次反卷积上采样后,我们获得的特征图的高和宽变为原来的8倍,此时特征图的高和宽为128x128,通道数为64。此时我们获得了一个64x128x128的有效特征图(高分辨率特征图)。

3、Center head

通过上一步我们可以获得一个64x128x128的高分辨率特征图。这个特征图相当于将整个图片划分成128x128个区域,每个区域存在一个特征点,如果某个物体的中心落在这个区域,那么就由这个特征点来确定。

我们可以利用这个特征图进行三个卷积,分别是:

1)热力图预测,此时卷积的通道数为num_classes,最终结果为(num_classes ,128,128),代表每一个热力点是否有物体存在,以及物体的种类;

2)偏移预测,此时卷积的通道数为2,最终结果为(2,128,128),代表每一个物体中心距离热力点偏移的情况;

3)宽高预测,此时卷积的通道数为2,最终结果为(2,128,128),代表每一个物体宽高的预测情况;

二、heatmap(热力图)的理解和生成

1、heatmap的理解

CenterNet将目标当成一个点来检测,即用目标box的中心点来表示这个目标,预测目标的中心点偏移量(offset),宽高(size)来得到物体实际box,而heatmap则是表示分类信息。每一个类别都有一张heatmap,每一张heatmap上,若某个坐标处有物体目标的中心点,即在该坐标处产生一个keypoint(用高斯圆表示),如下图所示:

 

 2、产生heatmap的步骤如下:

如下图左边是缩放后送入网络的图片,尺寸为512x512,右边是生成的heatmap图,尺寸为128x128(网络最后预测的heatmap尺度为128x128。其步骤如下:

1)将目标的box缩放到128x128的尺度上,然后求box的中心点坐标并取整,设为point

2)根据目标box大小计算高斯圆的半径,设为R

3)在heatmap图上,以point为圆心,半径为R填充高斯函数计算值。(point点处为最大值,沿着半径向外按高斯函数递减)

(注意:由于两个目标都是猫,属于同一类别,所以在同一张heatmap上。若还有一只狗,则狗的keypoint在另外一张heatmap上)

3、heatmap高斯函数半径的确定

  heatmap上的关键点之所以采用二维高斯核来表示,是由于对于在目标中心点附近的一些点,期预测出来的box和gt_box的IOU可能会大于0.7,不能直接对这些预测值进行惩罚,需要温和一点,所以采用高斯核。借用下大佬们的解释,如下图所示:

  关于高斯圆的半径确定,主要还是依赖于目标box的宽高,其计算方法为下图所示。 实际情况中会取IOU=0.7,即下图中的overlap=0.7作为临界值,然后分别计算出三种情况的半径,取最小值作为高斯核的半径r

 

这部分代码如下:

def gaussian_radius(det_size, min_overlap=0.7):

    height, width = det_size

    a1 = 1

    b1 = (height + width)

    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)

    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)

    r1 = (b1 + sq1) / 2

    a2 = 4

    b2 = 2 * (height + width)

    c2 = (1 - min_overlap) * width * height

    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)

    r2 = (b2 + sq2) / 2

    a3 = 4 * min_overlap

    b3 = -2 * min_overlap * (height + width)

    c3 = (min_overlap - 1) * width * height

    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)

    r3 = (b3 + sq3) / 2

    return min(r1, r2, r3)  

具体可以参考:https://zhuanlan.zhihu.com/p/388024445

三、loss函数

centernet的损失包括三部分,heatmap的损失,目标宽高预测损失,目标中心点偏移损失。

1、heatmap的损失

热力图的loss采用focal loss的思想进行运算,其中 α 和 β 是Focal Loss的超参数,而其中的N是图像的关键点数量(正样本个数),用于进行标准化。 α 和 β在这篇论文中和分别是2和4。在公式中,为预测值,Yxyc为标注真实值。

整体思想和Focal Loss类似,对于容易分类的样本,适当减少其训练比重也就是loss值。

具体可以参考:https://zhuanlan.zhihu.com/p/66048276

2、中心点偏移损失

损失函数公式如下, 其只对正样本的偏移值损失进行计算。其中 表示预测的偏移值,p为图片中目标中心点坐标,R为缩放尺度,为缩放后中心点的近似整数坐标

  

3、宽高预测损失值

​损失函数公式如下,也是只对正样本的损失值计算,Spk为预测尺寸,​Sk为真实尺寸

偏移值的loss和宽高的loss使用的是普通L1损失函数。偏移值预测和宽高预测都直接采用了特征图坐标的尺寸,也就是在0到128之内。由于wh宽高预测的loss会比较大,其loss乘上了一个系数,论文是0.1。偏移值预测的系数则为1。

四、解码

在推理阶段,首先将图像输入网络获取heatmap,然后采用3x3的MaxPool检测当前热点的值是否比周围的八个近邻点(八方位)都大(或者等于)的点(类似于anchor-based检测中nms的效果)。取top100个这样的点(所有类别一起)。使用回归得到的偏移和尺寸值进行计算得到bbox:

最终,选择大于阈值的中心点作为最终结果。

在原论文提到,centernet不像其它目标检测算法,在解码之后需要进行非极大值抑制,centernet的非极大值抑制在解码之前进行(利用3x3的池化核)。但在实际使用时发现,当目标为小目标时,确实可以不在解码之后进行非极大值抑制的后处理,如果目标为大目标,网络无法正确判断目标的中心时,还是需要进行非极大值抑制的后处理的。

本文参考以下文章

https://zhuanlan.zhihu.com/p/388024445

https://zhuanlan.zhihu.com/p/66048276

https://www.codenong.com/cs106869363/

猜你喜欢

转载自blog.csdn.net/wanchengkai/article/details/128707828
今日推荐