Serie OpenPCDet | 7. Prueba del modelo PointPillars Análisis del proceso del conjunto de datos KITTI

Proceso de prueba del modelo

Para el modelo, el proceso de entrenamiento consiste en calcular los parámetros para construir el modelo de entrenamiento de pérdidas, y el proceso de verificación consiste en probar el efecto de los parámetros actuales del modelo. Por lo tanto, para la estructura del modelo, es necesario planificar por separado el proceso de prueba y el proceso de entrenamiento . En la detección 3d de nubes de puntos, esto se refleja principalmente en la capa de predicción dense_head. Para el modelo, la diferencia entre este y el proceso de entrenamiento es la siguiente:
inserte la descripción de la imagen aquí

  • Para la diferencia en el procesamiento de dense_head:
# 功能:构建PointPillar的dense head模块部分
class AnchorHeadSingle(AnchorHeadTemplate):
    def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
                 predict_boxes_when_training=True, **kwargs):
        super().__init__(   # 基类没有传参input_channels和voxel_size
            model_cfg=model_cfg, num_class=num_class,
            class_names=class_names, grid_size=grid_size,
            point_cloud_range=point_cloud_range,
            predict_boxes_when_training=predict_boxes_when_training
        )
        ......       

    def forward(self, data_dict):
        ......
        
        # 训练过程
        if self.training:
            targets_dict = self.assign_targets(     # 获取gt信息
                gt_boxes=data_dict['gt_boxes']
            )
            self.forward_ret_dict.update(targets_dict)   # 此时记录gt信息以及预测信息,来进行后续的loss计算

        # 测试过程
        if not self.training or self.predict_boxes_when_training:
            batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
                batch_size=data_dict['batch_size'],
                cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
            )
            data_dict['batch_cls_preds'] = batch_cls_preds
            data_dict['batch_box_preds'] = batch_box_preds
            data_dict['cls_preds_normalized'] = False

        return data_dict    # 返回更新后的data_dict
  • Para la diferencia en el flujo del algoritmo después del procesamiento por cada módulo:
# 功能:基于Detector3DTemplate构建PointPillar算法结构
class PointPillar(Detector3DTemplate):
    def __init__(self, model_cfg, num_class, dataset):
        """
        Args:
            model_cfg:   yaml配置文件的MODEL部分
            num_class:   类别数目(kitti数据集一般用3个类别:'Car', 'Pedestrian', 'Cyclist')
            dataset:     训练数据集
        """
        super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)   # 初始化基类

        # 网络的各处理模块已经存储在self中(vfe / map_to_bev / backbone_2d ...)
        self.module_list = self.build_networks()    # 真正构建模型的处理函数,Detector3DTemplate的子函数

    def forward(self, batch_dict):
        # 各模块分别进行特征处理,更新batch_dict,然后将预测信息与gt信息保存在forward_ret_dict字典中来进行后续的损失计算
        for cur_module in self.module_list:
            batch_dict = cur_module(batch_dict)

        # 训练过程进行损失计算
        if self.training:
            loss, tb_dict, disp_dict = self.get_training_loss()    # 损失计算

            ret_dict = {
    
    
                'loss': loss
            }
            return ret_dict, tb_dict, disp_dict
            
        # 测试过程进行后处理返回预测结果
        else:
            pred_dicts, recall_dicts = self.post_processing(batch_dict)
            return pred_dicts, recall_dicts

    ......

Estos módulos de funciones básicas se registrarán por separado a continuación.


1. AnchorHeadTemplate.generate_predicted_boxes部分

La entrada de datos inicial de la función generate_predicted_boxes es:

inserte la descripción de la imagen aquí

Finalmente almacene las características en un diccionario:

inserte la descripción de la imagen aquí


2. Detector3DTemplate.post_procesamiento parte

El diccionario batch_dict actualizado después de que se procesa la función AnchorHeadTemplate.generate_predicted_boxes es la entrada de la función de posprocesamiento aquí. Se puede decir que este batch_dict se ejecuta a través del proceso de propagación hacia adelante de todo el modelo, incluida la parte de posprocesamiento de la fase de prueba. Este batch_dict no se usa durante el entrenamiento.


3. Parte KittiDataset.generate_prediction_dicts

Esta parte realiza el módulo de procesamiento de predicción de datos kitti


4. KittiDataset.parte de evaluación

Esta parte es el módulo de verificación específico de los datos de kitti.


Supongo que te gusta

Origin blog.csdn.net/weixin_44751294/article/details/130556374
Recomendado
Clasificación