CenterNet: Objects as points 算法及作者代码walk through

相信论文的大体意思大家都有看过很多介绍,论文通过预测目标中心点和目标w和h来得到检测框,而且经过测试,该算法的框预测明显优于Yolov3,结果就不贴了。

Loss

在这里插入图片描述
上面是总的Loss ,分别包含Lk中心点heatmap预测loss,Lsize:object size Loss,Loff: 由于下采样导致的离散化错误,因此增加了中心点的offset预测,论文中解释如下:

To recover the discretization error caused by the output stride, we additionally predict a local offset for each center point.

跟我们的预想一样,center point loss 和 object size loss,只是为了克服4倍下采样误差,增加了一个offset loss。
代码:loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \ opt.off_weight * off_loss
hm_loss: Lk, wh_loss: Lsize, off_loss:Loff.

Lk loss

在这里插入图片描述
Ground Truth 为以目标中心点为中心,产生的高斯分布,大小为[W/R, H/R, C]。
中心点公式很简单,就是x,y坐标除R(Stride)。
而ground Truth为
在这里插入图片描述
而产生的数据如下图所示
在这里插入图片描述
参考网址:https://zhuanlan.zhihu.com/p/66048276

代码

  1. Groud Truth生成
	# 下采样后的特征图大小
    output_h = input_h // self.opt.down_ratio
    output_w = input_w // self.opt.down_ratio
    num_classes = self.num_classes
    trans_output = get_affine_transform(c, s, 0, [output_w, output_h])
	# 和Yolo很像,生成zeros的gt,每个点存在目标的可能,类似one hot结构
    hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
    # 第一个参数代表最多存在的目标个数,后面的2代表width和height
    wh = np.zeros((self.max_objs, 2), dtype=np.float32)
    dense_wh = np.zeros((2, output_h, output_w), dtype=np.float32)
    # size 回归
    reg = np.zeros((self.max_objs, 2), dtype=np.float32)
    ind = np.zeros((self.max_objs), dtype=np.int64)
    reg_mask = np.zeros((self.max_objs), dtype=np.uint8)
    cat_spec_wh = np.zeros((self.max_objs, num_classes * 2), dtype=np.float32)
    cat_spec_mask = np.zeros((self.max_objs, num_classes * 2), dtype=np.uint8)

TODO 再看看代码再写,Sorry

猜你喜欢

转载自blog.csdn.net/weixin_39610043/article/details/94721395
今日推荐