OpenPCDet系列 | 7.1 KITTI数据集测试流程predicted_boxes预测

AnchorHeadTemplate.generate_predicted_boxes部分

测试流程的结构图如下所示:
在这里插入图片描述

generate_predicted_boxes函数一开始的数据传入为:
在这里插入图片描述

首先对于各类预测的特征图进行重新reshape处理,将anchor那一维度进行拼接操作,比如:(16, 248, 216, 42) -> (16, 321408, 7)。但是这里需要注意。特征预测的box信息是基于anchor的一个偏移,也就是编码后的偏移系数,所以需要对其进行按原路解码操作,才可以获得真实的box信息。

# 各种维度的reshape处理
anchors = torch.cat(self.anchors, dim=-3)   # (1, 248, 216, 3, 2, 7)
num_anchors = anchors.view(-1, anchors.shape[-1]).shape[0]  # 3个类别+2个方向 在特征图上的总anchor数 321408
batch_anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)     # (16, 321408, 7)
batch_cls_preds = cls_preds.view(batch_size, num_anchors, -1).float() \
    if not isinstance(cls_preds, list) else cls_preds   # (16, 248, 216, 18) -> (16, 321408, 3)
batch_box_preds = box_preds.view(batch_size, num_anchors, -1) if not isinstance(box_preds, list) \
    else torch.cat(box_preds, dim=1).view(batch_size, num_anchors, -1)    # (16, 248, 216, 42) -> (16, 321408, 7)
# 解码回去
batch_box_preds = self.box_coder.decode_torch(batch_box_preds, batch_anchors)   # 根据pred和anchor解码为正常的尺寸 (16, 321408, 7)

如果存在方向预测特征,同样对其进行reshape处理。这里的预测特征(16, 321408, 2)表示每个anchor对两个方向的预测概率,那么这里需要选择较高的概率的那个索引。torch.max函数的第一个返回结果是较高的数值,第二个返回的结果是较高数值的索引。所以,这里将预测特征图根据概率转换为01的预测结果。

if dir_cls_preds is not None:
    dir_offset = self.model_cfg.DIR_OFFSET      # 0.78539
    dir_limit_offset = self.model_cfg.DIR_LIMIT_OFFSET  # 0
    dir_cls_preds = dir_cls_preds.view(batch_size, num_anchors, -1) if not isinstance(dir_cls_preds, list) \
        else torch.cat(dir_cls_preds, dim=1).view(batch_size, num_anchors, -1)  # (16, 321408, 2)
    # 确定正向还是反向
    dir_labels = torch.max(dir_cls_preds, dim=-1)[1]     # (16, 321408)

最后,对角度进行限制到0-π之间,构建出准确的gt偏航角度。最后将真实预测的box信息以及label信息返回。特征维度分别是:(16, 321408, 7)和(16, 321408, 3)

period = (2 * np.pi / self.model_cfg.NUM_DIR_BINS)  # pi
dir_rot = common_utils.limit_period(    # 限制在0到pi之间
    batch_box_preds[..., 6] - dir_offset, dir_limit_offset, period
)
# period * dir_labels.to(batch_box_preds.dtype) 如果label为1,则为π;否则仍然保存0;
batch_box_preds[..., 6] = dir_rot + dir_offset + period * dir_labels.to(batch_box_preds.dtype)

return batch_cls_preds, batch_box_preds

最后将特征存储在字典中:

在这里插入图片描述


猜你喜欢

转载自blog.csdn.net/weixin_44751294/article/details/130597828
7.1