简析SA-SSD在预处理训练评估的框架

1. 前言

作为一个小白,笔者认为,从实践角度去深入一个深度学习领域(比如3D目标检测),不可忽视两件最为基础的事情:(1)怎样处理数据集;(2)怎样评估实验结果。这篇博客分析cvpr2020论文SA-SSD: Structure Aware Single-stage 3D Object Detection from Point Cloud开源代码中的数据预处理,模型训练和结果评估这几块基础代码。

总结自己分析代码的两个小技巧:(1)从顶到下理解代码;(2)没有必要读懂全部的代码,只需要读懂需要读懂的部分即可。

2. 数据预处理

从代码的ReadMe获知数据处理使用tools/create_data.py文件。这个文件主要做三件事。

(1)调用函数create_kitti_info_file获取数据集中点云图像路径;

(2)调用函数create_reduced_point_cloud获取相机视场内的点云;

(3)调用函数create_groundtruth_database生成3D目标检测真值;

2.1 获取数据集中点云图像路径

这一节分析函数create_kitti_info_file。数据集分成三个数据子集,训练集(Train),验证集(Validation),和测试集(Test)。测试集没有真值,测试集的结果需要提交到KITTI上,得到3D目标各个类别检测精度。在KITTI上提交结果需要先注册KITTI账号,用学校或企业邮箱注册即可。

第一步:获取各个子集index

SA-SSD开源代码中下载imagesets.tar.gz,里面包含train.txtval.txttest.txt。这些文本中存放各个子数据集的index。函数_read_imageset_file会读取它们的index,并保存成list数据结构:

    train_img_ids = _read_imageset_file("./data/ImageSets/train.txt")
    val_img_ids = _read_imageset_file("./data/ImageSets/val.txt")
    test_img_ids = _read_imageset_file("./data/ImageSets/test.txt")

第二步:获取各个子集中点云和图像的路径

对各个子数据集调用函数get_kitti_image_info,以获取点云和图像的路径。函数_calculate_num_points_in_gt用来数相机视场范围内LiDAR点云的个数。点云总数会存放在字典型变量kitti_infos_train中。结果将存放在pkl类型文件。

对训练数据而言:

    kitti_infos_train = kitti.get_kitti_image_info(
        data_path,
        training=True,
        velodyne=True,
        calib=True,
        image_ids=train_img_ids,
        relative_path=relative_path)
    _calculate_num_points_in_gt(data_path, kitti_infos_train, relative_path)
    filename = save_path / 'kitti_infos_train.pkl'
    print(f"Kitti info train file is saved to {filename}")
    with open(filename, 'wb') as f:
        pickle.dump(kitti_infos_train, f)

上述这个过程适用于验证集和测试集的处理。把path换了就行。

简要分析函数get_kitti_image_info,它的输出image_info是一个字典型变量,包含'image_idx'pointcloud_num_featuresvelodyne_pathimg_pathimg_shapeannosimage_info还包含标定的参数,比如相机内参数,calib/P0calib/P4,雷达相机外参数,calib/R0_rectcalib/Tr_velo_to_cam

我重点交代annos,它也是一个字典型变量,跟目标检测标注相关。生成它的代码如下所示:

        if label_info:
        	# 找到目标检测标签label的路径
            label_path = get_label_path(idx, path, training, relative_path)
            if relative_path:
            	# 如果是相对路径就加上前缀,得到绝对路径
                label_path = str(root_path / label_path)
            # 读取目标检测标签
            annotations = get_label_anno(label_path)

        if annotations is not None:
            image_info['annos'] = annotations
            # 根据kitti官方指标(比如遮挡度),给目标检测标签添加难易度评价
            # annos["difficulty"] = 0(Easy), 1(Mid),2(Hard)
            add_difficulty_to_annos(image_info)

有必要了解函数get_label_anno,来看一下哈:

def get_label_anno(label_path):
    annotations = {}
    annotations.update({
        'name': [],
        'truncated': [],
        'occluded': [],
        'alpha': [],
        'bbox': [],
        'dimensions': [],
        'location': [],
        'rotation_y': []
    })
    with open(label_path, 'r') as f:
        lines = f.readlines()
    # if len(lines) == 0 or len(lines[0]) < 15:
    #     content = []
    # else:
    content = [line.strip().split(' ') for line in lines]
    num_objects = len([x[0] for x in content if x[0] != 'DontCare'])
    annotations['name'] = np.array([x[0] for x in content])
    num_gt = len(annotations['name'])
    annotations['truncated'] = np.array([float(x[1]) for x in content])
    annotations['occluded'] = np.array([int(float(x[2])) for x in content])
    annotations['alpha'] = np.array([float(x[3]) for x in content])
    annotations['bbox'] = np.array(
        [[float(info) for info in x[4:8]] for x in content]).reshape(-1, 4)
    # dimensions will convert hwl format to standard lhw(camera) format.
    annotations['dimensions'] = np.array(
        [[float(info) for info in x[8:11]] for x in content]).reshape(
            -1, 3)[:, [2, 0, 1]]
    annotations['location'] = np.array(
        [[float(info) for info in x[11:14]] for x in content]).reshape(-1, 3)
    annotations['rotation_y'] = np.array(
        [float(x[14]) for x in content]).reshape(-1)
    if len(content) != 0 and len(content[0]) == 16:  # have score
        annotations['score'] = np.array([float(x[15]) for x in content])
    else:
        annotations['score'] = np.zeros((annotations['bbox'].shape[0], ))
    index = list(range(num_objects)) + [-1] * (num_gt - num_objects)
    annotations['index'] = np.array(index, dtype=np.int32)
    annotations['group_ids'] = np.arange(num_gt, dtype=np.int32)
    return annotations

annos中,truncatedoccluded用于衡量该3d目标被遮挡的程度,用于给该3d目标被检测到的难易程度打分。bboxdimensions都表示3d目标的长宽高,它们的区别见代码中的注释。location表示3d框中心点的位置。rotation_y表示3d框中Yaw角度。score表示目标类别置信度。num_objects表示当前点云中有多少3d目标。name指3d目标类别。indexgroup_ids指该3d目标的索引和类索引。这些变量的英文解释如下所示:

在这里插入图片描述
图1:标注信息释义转自博客

函数_calculate_num_points_in_gt会根据雷达外参数把点云投在相机坐标系下,然后滤除相机视场外的点,计算剩下点云的个数,把结果更新到annos["num_points_in_gt"]中。

KITTI官方给出对Easy,Medium,Hard的定量解释。
Easy: Min. bounding box height: 40 Px, Max. occlusion level: Fully visible, Max. truncation: 15 %
Moderate: Min. bounding box height: 25 Px, Max. occlusion level: Partly occluded, Max. truncation: 30 %
Hard: Min. bounding box height: 25 Px, Max. occlusion level: Difficult to see, Max. truncation: 50 %

2.2 获取相机视场内的点云

这一节分析函数create_reduced_point_cloud

第一步:读各个子数据集在2.1节整理的pkl类型文件

    if train_info_path is None:
        train_info_path = pathlib.Path(data_path) / 'kitti_infos_train.pkl'
    if val_info_path is None:
        val_info_path = pathlib.Path(data_path) / 'kitti_infos_val.pkl'
    if test_info_path is None:
        test_info_path = pathlib.Path(data_path) / 'kitti_infos_test.pkl'

第二步:获取相机视场范围内的点云

	# 视场范围内的点云保存在velodyne_reduced中
    _create_reduced_point_cloud(data_path, train_info_path, save_path)
    _create_reduced_point_cloud(data_path, val_info_path, save_path)
    _create_reduced_point_cloud(data_path, test_info_path, save_path)

2.3 获取3D目标检测真值

简要分析函数create_groundtruth_database。这一段代码有点杂乱。大致是每一类别的3d框都收集起来。把3D框真值存在pkl文件中。它的具体用处看后续代码中怎么调用吧。

2.4 小结

学习3D框标注信息以及如何处理它。

3. 训练框架简介

SA-SSD代码实在mmdetection平台上开发的。所以它无论是训练还是做预测,都会按照mmdetection的一套流程。关于mmdetection各种api的介绍可以参考这篇非常不错的博客。也可以参考mmdetection官方文档这篇知乎帖子写的也很不错。

如果要训练一个网络,会执行下面代码:

python3 train.py ../configs/car_cfg.py

其中car_cfg.py是配置文件,用于保存模型超参数,训练超参数,和测试超参数,以及学习策略配置等等。

train.py文件中,用下述代码生成训练数据集:

	# 生成训练数据集
    train_dataset = get_dataset(cfg.data.train)

	# 开始训练
	# mmdetection的一行代码就训练,然而这种简洁的操作多少有些惊艳
	# 后来发现,损失函数(smooth l1),优化器(SGD),学习率,训练批次等都在car_cfg.py定义好啦
    train_detector(
        model,
        train_dataset,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)

看看cfg内容(一堆参数):

# 虽然参数很多,但是看到超参数的名字不难猜到它的意思
# dataset settings
# model training and testing settings
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            pos_iou_thr=0.6,
            neg_iou_thr=0.45,
            min_pos_iou=0.45, # this one is to limit the force assignment
            ignore_iof_thr=-1,
            similarity_fn ='NearestIouSimilarity'
        ),
        nms=dict(
            nms_across_levels=False,
            nms_pre=2000,
            nms_post=2000,
            nms_thr=0.7,
            min_bbox_size=0
        ),
        allowed_border=0,
        pos_weight=-1,
        smoothl1_beta=1 / 9.0,
        debug=False),
    extra=dict(
        assigner=dict(
            pos_iou_thr=0.7,
            neg_iou_thr=0.7,
            min_pos_iou=0.7,
            ignore_iof_thr=-1,
            similarity_fn ='RotateIou3dSimilarity'
        )
    )
)

dataset_type = 'KittiLiDAR'
data_root = '/home/billyhe/data/KITTI/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        root=data_root + 'training/',
        ann_file=data_root + 'ImageSets/train.txt',
        img_prefix=None,
        img_scale=(1242, 375),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0.5,
        with_mask=False,
        with_label=True,
        with_point=True,
        class_names = ['Car', 'Van'],
        augmentor=dict(
            type='PointAugmentor',
            root_path=data_root,
            info_path=data_root + 'kitti_dbinfos_train.pkl',
            sample_classes=['Car'],
            min_num_points=5,
            sample_max_num=15,
            removed_difficulties=[-1],
            global_rot_range=[-0.78539816, 0.78539816],
            gt_rot_range=[-0.78539816, 0.78539816],
            center_noise_std=[1., 1., .5],
            scale_range=[0.95, 1.05]
        ),
        generator=dict(
            type='VoxelGenerator',
            voxel_size=[0.05, 0.05, 0.1],
            point_cloud_range=[0, -40., -3., 70.4, 40., 1.],
            max_num_points=5,
            max_voxels=20000
        ),
        anchor_generator=dict(
            type='AnchorGeneratorStride',
            sizes=[1.6, 3.9, 1.56],
            anchor_strides=[0.4, 0.4, 1.0],
            anchor_offsets=[0.2, -39.8, -1.78],
            rotations=[0, 1.57],
        ),
        anchor_area_threshold=1,
        out_size_factor=8,
        test_mode=False),

    # 做验证的超参数和训练超参数是一样的,就不放出来了。
    # 在val中,test_mode=True
    # val=dict(...)
)

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.001)
optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2))
# learning policy
lr_config = dict(
    policy='cosine',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
)

checkpoint_config = dict(interval=5)
log_config = dict(
    interval=20,
    hooks=[
        dict(type='TextLoggerHook'),
    ])
total_epochs = 50
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = '../saved_model_vehicle'
load_from = None
resume_from = None
workflow = [('train', 1)]

然后瞧瞧函数get_dataset,它的核心操作是调用函数obj_from_dict(大概是根据输入需求写data_info,然后从datasets读出dset,添加至dsetsdsets是输出的训练数据):

    dsets = []
    for i in range(num_dset):
        data_info = copy.deepcopy(data_cfg)
        data_info['ann_file'] = ann_files[i]
        data_info['proposal_file'] = proposal_files[i]
        data_info['img_prefix'] = img_prefixes[i]
        if generator is not None:
            data_info['generator'] = generator
        if anchor_generator is not None:
            data_info['anchor_generator'] = anchor_generator
        if augmentor is not None:
            data_info['augmentor'] = augmentor
        if target_encoder is not None:
            data_info['target_encoder'] = target_encoder
        # 核心操作
        dset = obj_from_dict(data_info, datasets)
        dsets.append(dset)
    if len(dsets) > 1:
        dset = ConcatDataset(dsets)
    else:
        dset = dsets[0]
    return dset

而函数train_detector是一套模板化流程,其中调用了torchDataLoader。更为具体的分析可以参考mmdetection说明博客

4. 结果评估

作为小白,处理一个完整的数据集,不仅仅需要对数据做预处理,还需要做结果评估(Evaluation),即计算预测值和真值间的误差。3D目标检测的误差计算稍微复杂。我们看看SA-SSD是怎样做处理的。

4.1 3D目标检测评估指标

怎样衡量网络预测的3D框和对应真值3D框之间的差异呢?对于KITTI数据集来说,拿车类目标做个例子,如果预测3D框和真值3D框之间重叠的部分占真值3D框的70%以上(70%是官方规定的),那么就可以认为预测3D框是准确的。如果10个预测结果中,有7个是准确的,那么车类目标识别率是70%。3D目标识别率简记为 A P AP ,或者 3 D A P 3D AP ,或者 A P 3 D AP_{3D}

使用IoU(Intersection-over-Union)表示预测3D框和真值3D框之间重叠的部分占真值3D框的比重。

每一种类别,KITTI官方规定的百分百阈值都不一样,可以参考下面原话:

For cars we require an 3D bounding box overlap of 70%, while for pedestrians and cyclists we require a 3D bounding box overlap of 50%.

把不同类别的目标识别结果平均起来,可以得到平均目标识别结果 m A P mAP

对于同一种类别,不断调整百分百阈值(称之为召回率,Recall),可以得到不同阈值下的 A P AP 值(称之为精度,Precision),进而会形成一条关于阈值百分比和 A P AP 的曲线,即召回率精度曲线(Recall-Precision Curve)。曲线覆盖面积将反应算法实际性能。曲线覆盖面积越趋近于1,算法性能越好。

除了使用 3 D A P 3D AP ,还会检测2D目标的精度 2 D A P 2D AP ,还会检测在 B E V BEV 视图下的目标精度 B E V A P BEV AP 。这些 A P AP 的定义都差不多。

4.2 评估流程

从代码的ReadMe获知结果评估使用tools/test.py文件。需要如下操作:

python3 test.py ../configs/car_cfg.py ../saved_model_vehicle/epoch_50.pth

其中car_cfg.py是配置文件,用于保存模型超参数,训练超参数,和测试超参数,以及学习策略配置等等。pth文件是模型训练好的参数。

评估流程代码如下。把Train那一块流程搞懂,这一块流程也很相似,比如常见函数get_datasetbuild_detector(加载SA-SSD网络,具体代码会在下一篇博客分析)。函数get_official_eval_result用来评估目标检测结果,输出 3 D A P 3D AP 2 D A P 2D AP ,和 B E V A P BEV AP

    dataset = utils.get_dataset(cfg.data.val)
    class_names = cfg.data.val.class_names
    if args.gpus == 1:
        model = build_detector(
            cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
        load_checkpoint(model, args.checkpoint)
        model = MMDataParallel(model, device_ids=[0])

        data_loader = build_dataloader(
            dataset,
            1,
            cfg.data.workers_per_gpu,
            num_gpus=1,
            #collate_fn= cfg.data.collate_fn,
            shuffle=False,
            dist=False)
        # 一口气输出全部测试集的预测结果
        outputs = single_test(model, data_loader, args.out, class_names)
    else:
        NotImplementedError
    # kitti evaluation
    gt_annos = kitti.get_label_annos(dataset.label_prefix, dataset.sample_ids)
    # 计算class_names类别下2D AP,3D Ap和BEV AP
    result = get_official_eval_result(gt_annos, outputs, current_classes=class_names)

5. 结束语

因为是小白,不懂3D目标检测,从白天看到黑夜,把整体框架看明白了。初次接触mmdetection有点难受。但是弄懂了,就觉得这个框架对调参数和搭建网络都非常友好。mmdetection是面向目标检测的代码集成库,还需要深入学习。在下一篇博客中,我将分析SA-SSD的网络细节。有时间我会分析一下mmdetection。不知道时间是否充裕,就不立flag啦。

发布了59 篇原创文章 · 获赞 36 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/qq_39732684/article/details/105111565