VoxelNext,全稀疏的3D目标检测网络

GitHub - dvlab-research/VoxelNeXt: VoxelNeXt: Fully Sparse VoxelNet for 3D Object Detection and Tracking (CVPR 2023)

https://arxiv.org/abs/2303.11301

摘要

当前3D目标检测模型,在检测部分都是沿用2D的方法,在dense的特征图上,通过预设的anchor或者center来预测3D的框,本文的创新是利用点云的稀疏的特性,在通过spconv提取特征后,不转化到dense的特征图,直接在稀疏的特征上进行3D框的预测。经验证,在常用的公开数据集上都取得了很好的效果。

1. 介绍

以常用的centerpoint模型为例,其中有,sparse to dense,虽然能有效工作,但是带来如下问题:计算资源的浪费、流程复杂、需要nms后处理。

 本文提出的方法,省去了center的anchor、sparse to dense、rpn、nms等步骤,直接而且是只在稀疏的特征位置上进行预测。

VoxelNext和Centerpoint,flops的优化。 

VoxelNext方法,相对centerpoint,FSD,在不同检测范围下的latency的对比,VoxelNext对长距离目标检测很友好。

2. 相关工作

       Lidar Detectors

        目前3D的检测器,通常都是参照2D的检测器,比如rcnn系列,比如centerpoint系列,虽然3D点云相对于2D数据本身是稀疏的,但是目前的检测器都还是在dense的特征图上进行预测的。本文进行一个变化点,直接在稀疏的特征上进行目标预测。

         Sparse Detectors

           分析了一些sparse的detectors,比如waymo的RSN,先在range image上segmentation提取前景点,然后在稀疏的前景点上进行目标检测;SWFormer,FSD都是一些稀疏检测的尝试,但是过程都偏复杂,本文用常用的稀疏卷积,尽量简化过程。
pillarnet

RSN

        Sparse Convlution Network

          因为稀疏卷积的高效性,现在是3D网络backbone的主流方法。但是一般都不直接用于检测头。目前有一些尝试优化,比如用transformer增加感受野,但是本文是通过额外的下采样来实现感受野的增加。

        3D Object Tracking        

          常见的是用kalman filter对结果进行跟踪,也有centertrack那样的直接预测速度,本文也利用了voxel的query来进行关联,有效的预测了物体中心的偏差。

3. Fully Sparse Voxel-based Network

        voxelnext网络结构示意图:

3.1 backbone adaptation

additional down sampling

在原先的下采样基础上,{1,2,4,8},{F 1 , F 2 , F 3 , F 4 },继续下采样{16,32},{F5,F6},然后把F4,F5,F6的spatial resolution align到F4,然后生成Fc。

 F是稀疏的特征,P是3D的坐标值。Fc就是F4,F5,F6的特征叠加。同时更新P5,P6到P4的尺寸。

x_conv5 = self.conv5(x_conv4)
x_conv6 = self.conv6(x_conv5)

x_conv5.indices[:, 1:] *= 2
x_conv6.indices[:, 1:] *= 4
x_conv4 = x_conv4.replace_feature(torch.cat([x_conv4.features, x_conv5.features, x_conv6.features]))
x_conv4.indices = torch.cat([x_conv4.indices, x_conv5.indices, x_conv6.indices])

sparse height compression

常规的做法,稀疏变dense,然后z维度加到channel维度。

这里,把稀疏的特征直接放置在bev平面,然后add求和。非常高效。

def bev_out(self, x_conv):
        features_cat = x_conv.features
        indices_cat = x_conv.indices[:, [0, 2, 3]]
        spatial_shape = x_conv.spatial_shape[1:]

        indices_unique, _inv = torch.unique(indices_cat, dim=0, return_inverse=True)
        features_unique = features_cat.new_zeros((indices_unique.shape[0], features_cat.shape[1]))
        features_unique.index_add_(0, _inv, features_cat)

        x_out = spconv.SparseConvTensor(
            features=features_unique,
            indices=indices_unique,
            spatial_shape=spatial_shape,
            batch_size=x_conv.batch_size
        )
        return x_out

spatially voxel prunning

在下采样的过程中,对不重要的背景特征进行prune。既可以突出前景,也可以提高运算效率。

3.2 sparse head

        1. class head

预测,NxF => NxK

target,靠近gt box中心最近的voxel,是positive sample。

loss, focal loss

inference, 使用sparse max pooling. voxel本身够稀疏,只在非空的位置操作。如果本身物体离的很近怎么办?

 实验发现,query voxel,并不一定在box中心,甚至不一定在box框内。

        2. regression head

positive的voxel筛选, N->n

预测,nxF => nx2(dx,dy), nx1(z), nx3(w,h,l), nx2(cos,sin)

loss, l1 loss

相关代码:

前向的网络结构,整体结构和之前的cenerhead比,卷积从2d的conv,变成2d的subMconv。hm还叫hm。

class SeparateHead(nn.Module):
    def __init__(self, input_channels, sep_head_dict, kernel_size, init_bias=-2.19, use_bias=False):
        super().__init__()
        self.sep_head_dict = sep_head_dict

        for cur_name in self.sep_head_dict:
            output_channels = self.sep_head_dict[cur_name]['out_channels']
            num_conv = self.sep_head_dict[cur_name]['num_conv']

            fc_list = []
            for k in range(num_conv - 1):
                fc_list.append(spconv.SparseSequential(
                    spconv.SubMConv2d(input_channels, input_channels, kernel_size, padding=int(kernel_size//2), bias=use_bias, indice_key=cur_name),
                    nn.BatchNorm1d(input_channels),
                    nn.ReLU()
                ))
            fc_list.append(spconv.SubMConv2d(input_channels, output_channels, 1, bias=True, indice_key=cur_name+'out'))
            fc = nn.Sequential(*fc_list)
            if 'hm' in cur_name:
                fc[-1].bias.data.fill_(init_bias)
            else:
                for m in fc.modules():
                    if isinstance(m, spconv.SubMConv2d):
                        kaiming_normal_(m.weight.data)
                        if hasattr(m, "bias") and m.bias is not None:
                            nn.init.constant_(m.bias, 0)

            self.__setattr__(cur_name, fc)

    def forward(self, x):
        ret_dict = {}
        for cur_name in self.sep_head_dict:
            ret_dict[cur_name] = self.__getattr__(cur_name)(x).features

        return ret_dict

目标编码,之前是dense的hm,以及gt对应的编码后的target boxes

现在是稀疏的hm,以及对应编码后的target boxes。

def assign_target_of_single_head(
            self, num_classes, gt_boxes, num_voxels, spatial_indices, spatial_shape, feature_map_stride, num_max_objs=500,
            gaussian_overlap=0.1, min_radius=2
    ):
        """
        Args:
            gt_boxes: (N, 8)
            feature_map_size: (2), [x, y]

        Returns:

        """
        heatmap = gt_boxes.new_zeros(num_classes, num_voxels)

        ret_boxes = gt_boxes.new_zeros((num_max_objs, gt_boxes.shape[-1] - 1 + 1))
        inds = gt_boxes.new_zeros(num_max_objs).long()
        mask = gt_boxes.new_zeros(num_max_objs).long()

        x, y, z = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2]
        coord_x = (x - self.point_cloud_range[0]) / self.voxel_size[0] / feature_map_stride
        coord_y = (y - self.point_cloud_range[1]) / self.voxel_size[1] / feature_map_stride

        coord_x = torch.clamp(coord_x, min=0, max=spatial_shape[1] - 0.5)  # bugfixed: 1e-6 does not work for center.int()
        coord_y = torch.clamp(coord_y, min=0, max=spatial_shape[0] - 0.5)  #

        center = torch.cat((coord_x[:, None], coord_y[:, None]), dim=-1)
        center_int = center.int()
        center_int_float = center_int.float()

        dx, dy, dz = gt_boxes[:, 3], gt_boxes[:, 4], gt_boxes[:, 5]
        dx = dx / self.voxel_size[0] / feature_map_stride
        dy = dy / self.voxel_size[1] / feature_map_stride

        radius = centernet_utils.gaussian_radius(dx, dy, min_overlap=gaussian_overlap)
        radius = torch.clamp_min(radius.int(), min=min_radius)

        for k in range(min(num_max_objs, gt_boxes.shape[0])):
            if dx[k] <= 0 or dy[k] <= 0:
                continue

            if not (0 <= center_int[k][0] <= spatial_shape[1] and 0 <= center_int[k][1] <= spatial_shape[0]):
                continue

            cur_class_id = (gt_boxes[k, -1] - 1).long()
            
            # 距离最近的voxel选为query voxel
            # inds也更新为此voxel的顺序
            distance = self.distance(spatial_indices, center[k])
            inds[k] = distance.argmin()
            mask[k] = 1
            
            
            # 在稀疏的hm上,进行hm的绘制   
            if 'gt_center' in self.gaussian_type:
                centernet_utils.draw_gaussian_to_heatmap_voxels(heatmap[cur_class_id], distance, radius[k].item() * self.gaussian_ratio)

            if 'nearst' in self.gaussian_type:
                centernet_utils.draw_gaussian_to_heatmap_voxels(heatmap[cur_class_id], self.distance(spatial_indices, spatial_indices[inds[k]]), radius[k].item() * self.gaussian_ratio)
            
            # △x,△y,是center和代理voxel的spatial inds的offset
            ret_boxes[k, 0:2] = center[k] - spatial_indices[inds[k]][:2]
            ret_boxes[k, 2] = z[k]
            ret_boxes[k, 3:6] = gt_boxes[k, 3:6].log()
            ret_boxes[k, 6] = torch.cos(gt_boxes[k, 6])
            ret_boxes[k, 7] = torch.sin(gt_boxes[k, 6])
            if gt_boxes.shape[1] > 8:
                ret_boxes[k, 8:] = gt_boxes[k, 7:-1]

        return heatmap, ret_boxes, inds, mask

hm以及box的decode

def decode_bbox_from_voxels_nuscenes(batch_size, indices, obj, rot_cos, rot_sin,
                            center, center_z, dim, vel=None, iou=None, point_cloud_range=None, voxel_size=None, voxels_3d=None,
                            feature_map_stride=None, K=100, score_thresh=None, post_center_limit_range=None, add_features=None):
    batch_idx = indices[:, 0]
    spatial_indices = indices[:, 1:]
    scores, inds, class_ids = _topk_1d(None, batch_size, batch_idx, obj, K=K, nuscenes=True)

    center = gather_feat_idx(center, inds, batch_size, batch_idx)
    rot_sin = gather_feat_idx(rot_sin, inds, batch_size, batch_idx)
    rot_cos = gather_feat_idx(rot_cos, inds, batch_size, batch_idx)
    center_z = gather_feat_idx(center_z, inds, batch_size, batch_idx)
    dim = gather_feat_idx(dim, inds, batch_size, batch_idx)
    spatial_indices = gather_feat_idx(spatial_indices, inds, batch_size, batch_idx)

    if not add_features is None:
        add_features = [gather_feat_idx(add_feature, inds, batch_size, batch_idx) for add_feature in add_features]

    if not isinstance(feature_map_stride, int):
        feature_map_stride = gather_feat_idx(feature_map_stride.unsqueeze(-1), inds, batch_size, batch_idx)

    angle = torch.atan2(rot_sin, rot_cos)
    xs = (spatial_indices[:, :, -1:] + center[:, :, 0:1]) * feature_map_stride * voxel_size[0] + point_cloud_range[0]
    ys = (spatial_indices[:, :, -2:-1] + center[:, :, 1:2]) * feature_map_stride * voxel_size[1] + point_cloud_range[1]
    #zs = (spatial_indices[:, :, 0:1]) * feature_map_stride * voxel_size[2] + point_cloud_range[2] + center_z

    box_part_list = [xs, ys, center_z, dim, angle]

    if not vel is None:
        vel = gather_feat_idx(vel, inds, batch_size, batch_idx)
        box_part_list.append(vel)

    if not iou is None:
        iou = gather_feat_idx(iou, inds, batch_size, batch_idx)
        iou = torch.clamp(iou, min=0, max=1.)

    final_box_preds = torch.cat((box_part_list), dim=-1)
    final_scores = scores.view(batch_size, K)
    final_class_ids = class_ids.view(batch_size, K)
    if not add_features is None:
        add_features = [add_feature.view(batch_size, K, add_feature.shape[-1]) for add_feature in add_features]

    assert post_center_limit_range is not None
    mask = (final_box_preds[..., :3] >= post_center_limit_range[:3]).all(2)
    mask &= (final_box_preds[..., :3] <= post_center_limit_range[3:]).all(2)

    if score_thresh is not None:
        mask &= (final_scores > score_thresh)

    ret_pred_dicts = []
    for k in range(batch_size):
        cur_mask = mask[k]
        cur_boxes = final_box_preds[k, cur_mask]
        cur_scores = final_scores[k, cur_mask]
        cur_labels = final_class_ids[k, cur_mask]
        cur_add_features = [add_feature[k, cur_mask] for add_feature in add_features] if not add_features is None else None
        cur_iou = iou[k, cur_mask] if not iou is None else None

        ret_pred_dicts.append({
            'pred_boxes': cur_boxes,
            'pred_scores': cur_scores,
            'pred_labels': cur_labels,
            'pred_ious': cur_iou,
            'add_features': cur_add_features,
        })
    return ret_pred_dicts

3.3 object tracking

voxel association

   query voxel作为center的代理,用l2 distance去关联query voxel。

        

猜你喜欢

转载自blog.csdn.net/huang_victor/article/details/130065986
今日推荐