相比yolo,ssd,faster_rcnn等依靠大量anchor的检测网络,CenterNet是一种anchor-free的目标检测网络。Centernet将目标看作一个点,一个目标由一个特征点确定。centernet将输入的图片划分成若干个区域,每个区域存在一个特征点。centernet网络的预测结果会判断这个特征是否由对应的物体,以及物体的种类和置信度;同时还会对特征点进行调整获得物体的中心坐标;并且回归出物体的宽高。
一、网络结构
Centernet网络结构主要分成下面3个部分:
1、主干特征提取网络
Centernet用到的主干特征网络有多种,如一般是以Hourglass Network(主要用于人体姿态估计)、DLANet(Deep Layer Aggregation)或者Resnet等。原始的centernet输入输入网络图片的尺寸为3x512x512的时候,最后一个特征图的shape为(2048,16,16)。
2、上采样,获得高分辨率特征图
对于该部分,centernet利用三次反卷积进行上采样,这3个反卷积的输出通道数分别为256,128,64。每一次反卷积,特征图的高和宽会变为原来的两倍,因此,在进行三次反卷积上采样后,我们获得的特征图的高和宽变为原来的8倍,此时特征图的高和宽为128x128,通道数为64。此时我们获得了一个64x128x128的有效特征图(高分辨率特征图)。
3、Center head
通过上一步我们可以获得一个64x128x128的高分辨率特征图。这个特征图相当于将整个图片划分成128x128个区域,每个区域存在一个特征点,如果某个物体的中心落在这个区域,那么就由这个特征点来确定。
我们可以利用这个特征图进行三个卷积,分别是:
1)热力图预测,此时卷积的通道数为num_classes,最终结果为(num_classes ,128,128),代表每一个热力点是否有物体存在,以及物体的种类;
2)偏移预测,此时卷积的通道数为2,最终结果为(2,128,128),代表每一个物体中心距离热力点偏移的情况;
3)宽高预测,此时卷积的通道数为2,最终结果为(2,128,128),代表每一个物体宽高的预测情况;
二、heatmap(热力图)的理解和生成
CenterNet将目标当成一个点来检测,即用目标box的中心点来表示这个目标,预测目标的中心点偏移量(offset),宽高(size)来得到物体实际box,而heatmap则是表示分类信息。每一个类别都有一张heatmap,每一张heatmap上,若某个坐标处有物体目标的中心点,即在该坐标处产生一个keypoint(用高斯圆表示),如下图所示:
2、产生heatmap的步骤如下:
如下图左边是缩放后送入网络的图片,尺寸为512x512,右边是生成的heatmap图,尺寸为128x128(网络最后预测的heatmap尺度为128x128。其步骤如下:
1)将目标的box缩放到128x128的尺度上,然后求box的中心点坐标并取整,设为point
2)根据目标box大小计算高斯圆的半径,设为R
3)在heatmap图上,以point为圆心,半径为R填充高斯函数计算值。(point点处为最大值,沿着半径向外按高斯函数递减)
(注意:由于两个目标都是猫,属于同一类别,所以在同一张heatmap上。若还有一只狗,则狗的keypoint在另外一张heatmap上)
3、heatmap高斯函数半径的确定
heatmap上的关键点之所以采用二维高斯核来表示,是由于对于在目标中心点附近的一些点,期预测出来的box和gt_box的IOU可能会大于0.7,不能直接对这些预测值进行惩罚,需要温和一点,所以采用高斯核。借用下大佬们的解释,如下图所示:
关于高斯圆的半径确定,主要还是依赖于目标box的宽高,其计算方法为下图所示。 实际情况中会取IOU=0.7,即下图中的overlap=0.7作为临界值,然后分别计算出三种情况的半径,取最小值作为高斯核的半径r
这部分代码如下:
def gaussian_radius(det_size, min_overlap=0.7):
height, width = det_size
a1 = 1
b1 = (height + width)
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
r1 = (b1 + sq1) / 2
a2 = 4
b2 = 2 * (height + width)
c2 = (1 - min_overlap) * width * height
sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
r2 = (b2 + sq2) / 2
a3 = 4 * min_overlap
b3 = -2 * min_overlap * (height + width)
c3 = (min_overlap - 1) * width * height
sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
r3 = (b3 + sq3) / 2
return min(r1, r2, r3)
具体可以参考:https://zhuanlan.zhihu.com/p/388024445
三、loss函数
centernet的损失包括三部分,heatmap的损失,目标宽高预测损失,目标中心点偏移损失。
1、heatmap的损失
热力图的loss采用focal loss的思想进行运算,其中 α 和 β 是Focal Loss的超参数,而其中的N是图像的关键点数量(正样本个数),用于进行标准化。 α 和 β在这篇论文中和分别是2和4。在公式中,为预测值,Yxyc为标注真实值。
整体思想和Focal Loss类似,对于容易分类的样本,适当减少其训练比重也就是loss值。
具体可以参考:https://zhuanlan.zhihu.com/p/66048276
2、中心点偏移损失
损失函数公式如下, 其只对正样本的偏移值损失进行计算。其中 表示预测的偏移值,p为图片中目标中心点坐标,R为缩放尺度,为缩放后中心点的近似整数坐标
3、宽高预测损失值
损失函数公式如下,也是只对正样本的损失值计算,Spk为预测尺寸,Sk为真实尺寸
偏移值的loss和宽高的loss使用的是普通L1损失函数。偏移值预测和宽高预测都直接采用了特征图坐标的尺寸,也就是在0到128之内。由于wh宽高预测的loss会比较大,其loss乘上了一个系数,论文是0.1。偏移值预测的系数则为1。
四、解码
在推理阶段,首先将图像输入网络获取heatmap,然后采用3x3的MaxPool检测当前热点的值是否比周围的八个近邻点(八方位)都大(或者等于)的点(类似于anchor-based检测中nms的效果)。取top100个这样的点(所有类别一起)。使用回归得到的偏移和尺寸值进行计算得到bbox:
最终,选择大于阈值的中心点作为最终结果。
在原论文提到,centernet不像其它目标检测算法,在解码之后需要进行非极大值抑制,centernet的非极大值抑制在解码之前进行(利用3x3的池化核)。但在实际使用时发现,当目标为小目标时,确实可以不在解码之后进行非极大值抑制的后处理,如果目标为大目标,网络无法正确判断目标的中心时,还是需要进行非极大值抑制的后处理的。
本文参考以下文章
https://zhuanlan.zhihu.com/p/388024445
https://zhuanlan.zhihu.com/p/66048276
https://www.codenong.com/cs106869363/