OpenPCDet系列 | 7.PointPillars模型测试KITTI数据集流程解析

模型的测试流程

对于模型来说,训练过程是为了计算构建损失训练模型的参数,验证过程是为了测试模型当前参数的效果。所以,对于模型结构来说需要分别为测试过程和训练过程进行分别规划。在点云的3d检测中,这里主要体现在dense_head预测层中。对于模型来说,其与训练流程的区别结构图如下:
在这里插入图片描述

  • 对于dense_head处理的区别:
# 功能:构建PointPillar的dense head模块部分
class AnchorHeadSingle(AnchorHeadTemplate):
    def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
                 predict_boxes_when_training=True, **kwargs):
        super().__init__(   # 基类没有传参input_channels和voxel_size
            model_cfg=model_cfg, num_class=num_class,
            class_names=class_names, grid_size=grid_size,
            point_cloud_range=point_cloud_range,
            predict_boxes_when_training=predict_boxes_when_training
        )
        ......       

    def forward(self, data_dict):
        ......
        
        # 训练过程
        if self.training:
            targets_dict = self.assign_targets(     # 获取gt信息
                gt_boxes=data_dict['gt_boxes']
            )
            self.forward_ret_dict.update(targets_dict)   # 此时记录gt信息以及预测信息,来进行后续的loss计算

        # 测试过程
        if not self.training or self.predict_boxes_when_training:
            batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
                batch_size=data_dict['batch_size'],
                cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
            )
            data_dict['batch_cls_preds'] = batch_cls_preds
            data_dict['batch_box_preds'] = batch_box_preds
            data_dict['cls_preds_normalized'] = False

        return data_dict    # 返回更新后的data_dict
  • 对于个模块处理后的算法流程区别:
# 功能:基于Detector3DTemplate构建PointPillar算法结构
class PointPillar(Detector3DTemplate):
    def __init__(self, model_cfg, num_class, dataset):
        """
        Args:
            model_cfg:   yaml配置文件的MODEL部分
            num_class:   类别数目(kitti数据集一般用3个类别:'Car', 'Pedestrian', 'Cyclist')
            dataset:     训练数据集
        """
        super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)   # 初始化基类

        # 网络的各处理模块已经存储在self中(vfe / map_to_bev / backbone_2d ...)
        self.module_list = self.build_networks()    # 真正构建模型的处理函数,Detector3DTemplate的子函数

    def forward(self, batch_dict):
        # 各模块分别进行特征处理,更新batch_dict,然后将预测信息与gt信息保存在forward_ret_dict字典中来进行后续的损失计算
        for cur_module in self.module_list:
            batch_dict = cur_module(batch_dict)

        # 训练过程进行损失计算
        if self.training:
            loss, tb_dict, disp_dict = self.get_training_loss()    # 损失计算

            ret_dict = {
    
    
                'loss': loss
            }
            return ret_dict, tb_dict, disp_dict
            
        # 测试过程进行后处理返回预测结果
        else:
            pred_dicts, recall_dicts = self.post_processing(batch_dict)
            return pred_dicts, recall_dicts

    ......

下面会分别对这些核心函数模块进行记录。


1. AnchorHeadTemplate.generate_predicted_boxes部分

generate_predicted_boxes函数一开始的数据传入为:

在这里插入图片描述

最后将特征存储在字典中:

在这里插入图片描述


2. Detector3DTemplate.post_processing部分

在AnchorHeadTemplate.generate_predicted_boxes函数处理完之后更新的batch_dict字典就是这边后处理函数的输入。可以说,这个batch_dict一直贯穿着整个模型的前向传播过程,包括测试阶段的后处理部分。在训练过程中就用不到这个batch_dict。


3. KittiDataset.generate_prediction_dicts部分

这部分进行kitti数据的预测处理模块


4. KittiDataset.evaluation部分

这部分进行kitti数据的具体验证模块


猜你喜欢

转载自blog.csdn.net/weixin_44751294/article/details/130556374