模型部署——CenterPoint转ONNX(自定义onnx算子)

CenterPoint基于OpenPcDet导出一个完整的ONNX,并用TensorRT推理,部署几个难点如下:

1.计算pillar中每个点相对几何中心的偏移,取下标方式进行计算是的整个计算图变得复杂,同时这种赋值方式导致运行在pytorch为浅拷贝,而在一些推理后端上表现为深拷贝

  • 修改代码,使用矩阵切片代替原先的操作,使导出的模型在推理后端上的行为结果和pytorch一致,并简化计算图,同时,计算网格坐标也需要修改,修改代码如下:
          # points_xyz = points[:, [0, 1, 2]].contiguous() 
          points_xyz = points[..., :3] 
          points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0)
  
          points_mean = scatter_mean(points_xyz,unq_inv)
          # # 每个点相对voxel质心的偏移
          f_cluster = points_xyz - points_mean[unq_inv, :] # torch.Size([1067877, 3])
          f_center = torch.zeros_like(points_xyz).to()
          # 每个点相对几何中心的偏移
          # f_center[:, 0] = points_xyz[:, 0] - (points_coords[:, 0].to(points_xyz.dtype) * self.voxel_x + self.x_offset)
          # f_center[:, 1] = points_xyz[:, 1] - (points_coords[:, 1].to(points_xyz.dtype) * self.voxel_y + self.y_offset)
          # f_center[:, 2] = points_xyz[:, 2] - self.z_offset
          device = points_xyz.device
          f_center = points_xyz - (points_coords * torch.tensor([self.voxel_x, self.voxel_y, self.voxel_z]).to(device) + torch.tensor([self.z_offset, self.y_offset, self.x_offset]).to(device))
  

2.torch_scatterscatter_meanscatter_max onnx不支持,需人为自定义onnx节点,后续并自定义tensorRTScatterMeanPluginScatterMaxPlugin算子

自定义onnx ScatterMax 算子如下,这里ScatterMax算子没有具体实现,仅为了增加相应的onnx节点,好导出onnx计算图,方便后续自定义实现TensorRT算子,实际上导出onnx并不能用onnxruntime来推理,这样做好处:我们可以只需要自定义实现TensorRT算子,对onnx增加相应节点就行,而不需要管具体的onnx算子实现。

class ScatterMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx,src,index):
    	  # 调unique仅为了输出对应的维度信息
        temp = torch.unique(src)
        out = torch.zeros((temp.shape[0],src.shape[1]),dtype=torch.float32,device=src.device)
        return out
    @staticmethod
    def symbolic(g,src,index):
        return g.op("xiaohu::ScatterMaxPlugin",src,index)

ScatterMeanPluginScatterBevPlugin节点和ScatterMaxPlugin节点定义方式是类似的

3.torch.stack算子 onnx不支持,导出onnx计算图很乱,将torch.stack和后续PointPillarScatter操作合并,一起定义为ScatterBevPlugin算子,自定义onnx节点和TensorRT算子来实现,ScatterBevPlugin实现功能和以下代码功能一致:

        voxel_coords = torch.stack((unq_coords // self.scale_xy, (unq_coords % self.scale_xy) // self.scale_y, unq_coords % self.scale_y,
                                   torch.zeros(unq_coords.shape[0]).to(unq_coords.device).int()), dim=1)
        # 将voxel_coords
        voxel_coords = voxel_coords[:, [0, 3, 2, 1]] # index,z,y,x

        pillars_feature = features.t()  # float32[64,pillar_num]
        spatial_feature = torch.zeros(64, 468 * 468,dtype=features.dtype, device=features.device)
        indices =  voxel_coords[:, 2] * 468 + voxel_coords[:, 3] #468 * y + x
        # indices = indices.type(torch.long)
        # tensors used as indices must be long, byte or bool tensors

        indices = indices.long()
        spatial_feature[:, indices] = pillars_feature
        spatial_feature = spatial_feature.view(1,64, 468, 468) # 对应onnx resahap

4.由于基于OpenPcDetCenterPoint用了动态体素化,计算体素信息调用torch.unique,而torch.unique算子 TensorRT 不支持,

torch.unique可以成功导出onnx,点击onnx 的unique节点,可以看出torch.unique输出有4个,而实际只有无重复的网格坐标unq_coords, 原始张量每个元素在处理后无重复数据中的索引unq_inv两个输出在后面用到了
在这里插入图片描述

在这里插入图片描述
onnx不支持torch all函数,实现TensorRT算子的本质用cuda/cpp实现Plugin::enqueue 函数,将下面python对应的一系列小型操作放在预处理实现,用cuda单独实现会更好点

        points_coords = torch.floor((points[:, [0,1,2]] - self.point_cloud_range[[0,1,2]]) / self.voxel_size[[0,1,2]]).int()
        # onnx不支持all
        # 如果张量中的所有元素为True,才返回True
        mask = ((points_coords >= 0) & (points_coords < self.grid_size[[0,1]])).all(dim=1)

        mask = torch.rand(150000).bool()
        # 会调用onnx里的GatherND算子
        points = points[mask]
        points_coords = points_coords[mask]

        merge_coords = points_coords[:, 0] * self.scale_y + points_coords[:, 1] 
        # sorted:是否返回无重复张量按照数值进行排序,默认是升序排列,sorted并非表示降序
        # return_inverse:是否返回原始张量中每个元素在处理后的无重复张量中对应的索引
        # return_counts:统计原始张量中每个独立元素的个数
        # dim:值沿那个维度进行unique的处理
        unq_coords, unq_inv, _ = torch.unique(merge_coords, return_inverse=True, return_counts=True, dim=0)

修改后,onnx输入有4个:原始点云points,无重复一维体素网格坐标unq_coords,原始张量中每个元素在处理后的无重复张量中对应的索引unq_inv,网格坐标coords

下面看CenterPoint转换出的onnx计算图:自定义onnx节点有 ScatterMaxPlugin,ScatterMeanPlugin,ScatterBevPlugin,用tensorRT实现就需要自定义ScatterMaxPlugin,ScatterMeanPlugin,ScatterBevPlugin 3个算子,后续会写下tenorRT自定义算子,并用cuda实现CenterPoint预处理和后处理,从而完成整个CenterPoint部署

onnx太小看不清,自定义onnx节点如下:
在这里插入图片描述

完整的onnx如下:

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_42905141/article/details/127545123