OpenPCDet シリーズ | 7. PointPillars モデルテスト KITTI データセットプロセス分析

モデルのテストプロセス

モデルの場合、トレーニング プロセスは損失トレーニング モデルを構築するためのパラメーターを計算することであり、検証プロセスはモデルの現在のパラメーターの効果をテストすることです。そのため、モデルの構造については、テスト工程とトレーニング工程を分けて計画する必要があります点群の 3D 検出では、これは主に dense_head 予測レイヤーに反映されます。モデルに関して、トレーニング プロセスとの違いは次のとおりです。
ここに画像の説明を挿入

  • dense_head 処理の違いについては、次のとおりです。
# 功能:构建PointPillar的dense head模块部分
class AnchorHeadSingle(AnchorHeadTemplate):
    def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
                 predict_boxes_when_training=True, **kwargs):
        super().__init__(   # 基类没有传参input_channels和voxel_size
            model_cfg=model_cfg, num_class=num_class,
            class_names=class_names, grid_size=grid_size,
            point_cloud_range=point_cloud_range,
            predict_boxes_when_training=predict_boxes_when_training
        )
        ......       

    def forward(self, data_dict):
        ......
        
        # 训练过程
        if self.training:
            targets_dict = self.assign_targets(     # 获取gt信息
                gt_boxes=data_dict['gt_boxes']
            )
            self.forward_ret_dict.update(targets_dict)   # 此时记录gt信息以及预测信息,来进行后续的loss计算

        # 测试过程
        if not self.training or self.predict_boxes_when_training:
            batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
                batch_size=data_dict['batch_size'],
                cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
            )
            data_dict['batch_cls_preds'] = batch_cls_preds
            data_dict['batch_box_preds'] = batch_box_preds
            data_dict['cls_preds_normalized'] = False

        return data_dict    # 返回更新后的data_dict
  • 各モジュールによる処理後のアルゴリズムの流れの違いについては、以下のとおりです。
# 功能:基于Detector3DTemplate构建PointPillar算法结构
class PointPillar(Detector3DTemplate):
    def __init__(self, model_cfg, num_class, dataset):
        """
        Args:
            model_cfg:   yaml配置文件的MODEL部分
            num_class:   类别数目(kitti数据集一般用3个类别:'Car', 'Pedestrian', 'Cyclist')
            dataset:     训练数据集
        """
        super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)   # 初始化基类

        # 网络的各处理模块已经存储在self中(vfe / map_to_bev / backbone_2d ...)
        self.module_list = self.build_networks()    # 真正构建模型的处理函数,Detector3DTemplate的子函数

    def forward(self, batch_dict):
        # 各模块分别进行特征处理,更新batch_dict,然后将预测信息与gt信息保存在forward_ret_dict字典中来进行后续的损失计算
        for cur_module in self.module_list:
            batch_dict = cur_module(batch_dict)

        # 训练过程进行损失计算
        if self.training:
            loss, tb_dict, disp_dict = self.get_training_loss()    # 损失计算

            ret_dict = {
    
    
                'loss': loss
            }
            return ret_dict, tb_dict, disp_dict
            
        # 测试过程进行后处理返回预测结果
        else:
            pred_dicts, recall_dicts = self.post_processing(batch_dict)
            return pred_dicts, recall_dicts

    ......

これらのコア機能モジュールは、以下に個別に記録されます。


1. AnchorHeadTemplate.generate_predicted_boxes 部分

generate_predicted_boxes 関数の初期データ入力は次のとおりです。

ここに画像の説明を挿入

最後に特徴を辞書に保存します。

ここに画像の説明を挿入


2. Detector3DTemplate.post_processing部分

AnchorHeadTemplate.generate_predicted_boxes 関数の処理後に更新されるbatch_dict ディクショナリは、ここでは後処理関数の入力となります。このbatch_dictは、テストフェーズの後処理部分を含むモデル全体の順伝播プロセスを実行すると言えます。このbatch_dictはトレーニング中には使用されません。


3. KittiDataset.generate_prediction_dicts部分

この部分は kitti データの予測処理モジュールを実行します


4. KittiDataset.評価部分

この部分は kitti データの特定の検証モジュールです


おすすめ

転載: blog.csdn.net/weixin_44751294/article/details/130556374