OpenPCDet系列 | 7.2 KITTI数据集测试流程post_processing处理

Detector3DTemplate.post_processing部分

测试流程的结构图如下所示:
在这里插入图片描述

post_processing函数的传入为generate_predicted_boxes部分的输出结果:

在这里插入图片描述

在后处理函数的具体操作中,有三大步骤:

1)分别获取当前点云帧的label信息和真实场景的gt信息
其中需要将分类预测输出转化为概率,其中利用预测概率的最大值作为其置信度,给后续的nms处理

batch_mask = index

box_preds = batch_dict['batch_box_preds'][batch_mask]   # 获取第index个点云的预测gt信息 (321408, 7)
src_box_preds = box_preds   # (321408, 7)
cls_preds = batch_dict['batch_cls_preds'][batch_mask]   # 获取第index个点云的label信息 (321408, 3) 这里是one-hot编码形式的预测
src_cls_preds = cls_preds   # (321408, 3)

cls_preds = torch.sigmoid(cls_preds)    # 转化为logits

2)对当前点云帧进行nms处理,挑选出被选择的anchor索引

# torch.max()函数的第一个返回值是每行的最大值,第二个返回值是每行最大值的索引
cls_preds, label_preds = torch.max(cls_preds, dim=-1)   # 这里是输入cls_preds是已经过sigmoid处理,构建成为概率的
if batch_dict.get('has_class_labels', False):   # False
    label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels'
    label_preds = batch_dict[label_key][index]
else:
    label_preds = label_preds + 1   # 索引标签+1

# 调用gpu函数进行nms,返回的是被选择的索引和索引分数
selected, selected_scores = model_nms_utils.class_agnostic_nms(
    box_scores=cls_preds,   # 将每个预测的anchor的最大概率当做是置信度
    box_preds=box_preds,
    nms_config=post_process_cfg.NMS_CONFIG,
    score_thresh=post_process_cfg.SCORE_THRESH  # 0.1
)

if post_process_cfg.OUTPUT_RAW_SCORE:
    max_cls_preds, _ = torch.max(src_cls_preds, dim=-1)
    selected_scores = max_cls_preds[selected]

final_scores = selected_scores          # 预测分数
final_labels = label_preds[selected]    # 预测类别
final_boxes = box_preds[selected]       # 预测box信息

其中,这里的nms函数返回的是挑选的anchor索引以及其置信度(sigmoid处理后的预测值)。但是核心的nms处理函数用了c++代码进行编译处理。

def class_agnostic_nms(box_scores, box_preds, nms_config, score_thresh=None):
    """
    利用nms挑选出anchor的索引以及对应的置信度
    """
    # 首先根据置信度阈值过滤掉部分box
    src_box_scores = box_scores
    if score_thresh is not None:    # 这里的阈值设定为0.1
        scores_mask = (box_scores >= score_thresh)  # 阈值过滤
        box_scores = box_scores[scores_mask]    # (145)
        box_preds = box_preds[scores_mask]      # (145, 7)

    selected = []
    if box_scores.shape[0] > 0:    # 筛选后还存在object
        box_scores_nms, indices = torch.topk(box_scores, k=min(nms_config.NMS_PRE_MAXSIZE, box_scores.shape[0]))    # 返回排序后的置信度大小以及其索引
        boxes_for_nms = box_preds[indices]  # 根据置信度大小排序
        keep_idx, selected_scores = getattr(iou3d_nms_utils, nms_config.NMS_TYPE)(
                boxes_for_nms[:, 0:7], box_scores_nms, nms_config.NMS_THRESH, **nms_config
        )
        selected = indices[keep_idx[:nms_config.NMS_POST_MAXSIZE]]  # 根据返回索引找出box索引值

    if score_thresh is not None:
        original_idxs = scores_mask.nonzero().view(-1)  # 原始大于置信度的索引
        # selected表示的box_scores的选择索引,经过这次索引,selected表示的是src_box_scores被选择的索引
        selected = original_idxs[selected]      # 在索引中挑选索引
    return selected, src_box_scores[selected]

3)进行找召回率计算

recall_dict = self.generate_recall_record(
    box_preds=final_boxes if 'rois' not in batch_dict else src_box_preds,
    recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
    thresh_list=post_process_cfg.RECALL_THRESH_LIST
)    

猜你喜欢

转载自blog.csdn.net/weixin_44751294/article/details/130597889
7.2
今日推荐