Convert a trained PointNet .pth model to an ONNX model, then run inference on it and deploy it in C++ with ONNX Runtime

1. Converting the .pth model to an ONNX model

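1.1 Place the .pth model to be converted in the models directory. (Note: export.py below opens `best_model.pth` relative to the current working directory, so either run it from the directory that contains the checkpoint or adjust the path in the script.)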

[Image omitted in this copy]

1.2 Create a new model-conversion script, export.py:

from models import pointnet2_cls_ssg
import os
import sys
import torch
import argparse

# Absolute directory of this script; it doubles as the project root.
ROOT_DIR = BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Make modules placed under <root>/models/log importable by name.
sys.path.append(os.path.join(BASE_DIR, "models", "log"))

def parse_args(argv=None):
    '''Parse command-line options for the ONNX export script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``use_cpu``, ``model``, ``num_category``,
        ``num_point`` and ``use_normals``.
    '''
    parser = argparse.ArgumentParser('Testing')
    # `store_true` combined with default=True can never become False, which made
    # every CUDA branch unreachable. `--use_cpu` is still accepted (CPU remains the
    # default), and `--use_gpu` writes False into the same destination.
    parser.add_argument('--use_cpu', action='store_true', default=True, help='use cpu mode')
    parser.add_argument('--use_gpu', dest='use_cpu', action='store_false', help='use cuda mode')
    # NOTE: the export script imports pointnet2_cls_ssg directly; this option is
    # not consulted there.
    parser.add_argument('--model', default='pointnet2_cls_ssg', help='model name [default: pointnet2_cls_ssg]')  # pointnet2_cls_ssg/pointnet_cls
    # Must match the number of classes the checkpoint was trained with.
    parser.add_argument('--num_category', default=3, type=int, choices=[2, 3, 10, 40], help='number of classes the model was trained with (e.g. ModelNet10/40)')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    return parser.parse_args(argv)

# Read CLI options and unpack the ones the export needs.
args = parse_args()
point_num = args.num_point
class_num = args.num_category
normal_channel = args.use_normals
use_gpu = not args.use_cpu

# Build the classifier (fixed to pointnet2_cls_ssg by the import above) and put it
# in inference mode before tracing.
model = pointnet2_cls_ssg.get_model(class_num, normal_channel)
if use_gpu:
    model = model.cuda()
model.eval()

# Load the trained weights. On CPU, remap tensors so a checkpoint that was saved
# on a GPU can still be loaded.
load_kwargs = {} if use_gpu else {'map_location': torch.device('cpu')}
checkpoint = torch.load('best_model.pth', **load_kwargs)
model.load_state_dict(checkpoint['model_state_dict'])

# Random example input, used only to record the graph:
# shape (1, 6 if normals are used else 3, point_num).
x = torch.rand(1, 6 if normal_channel else 3, point_num)
if use_gpu:
    x = x.cuda()

# Trace to TorchScript, then export the traced module to ONNX (opset 11).
traced_script_module = torch.jit.trace(model, x)
export_onnx_file = "cls.onnx"
torch.onnx.export(traced_script_module, x, export_onnx_file, opset_version=11)
# Optionally keep the TorchScript module as well:
# traced_script_module.save("cls.pt")

1.3 Modify pointnet2_utils.py

For the call torch.onnx.export(traced_script_module, x, export_onnx_file, opset_version=11) to succeed, pointnet2_utils.py has to be modified. The modified code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np


def timeit(tag, t):
    """Print the seconds elapsed since timestamp ``t`` under ``tag``.

    Returns a fresh ``time()`` timestamp so that calls can be chained.
    """
    elapsed = time() - t
    print("{}: {}s".format(tag, elapsed))
    return time()


def pc_normalize(pc):
    """
    Center a point cloud on its centroid and scale it so the farthest point lies at distance 1.

    Input:
        pc: 2-D array with one point per row (presumably [N, C] coordinates)
    Return:
        a new array of the same shape: ``pc`` minus its per-column mean, divided by
        the largest row-wise Euclidean norm of the centered points. If that radius
        is zero (all points coincide, e.g. a single point), the centered array is
        returned unscaled instead of dividing by zero and producing NaNs.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid  # new array; the caller's input is not modified
    # Radius: distance from the centroid to the farthest point.
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    if m > 0:
        pc = pc / m
    return pc


def square_distance(src, dst):
    """
    Pairwise squared Euclidean distances between two batched point sets.

    Uses the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2 * (s . d), so the whole
    computation is one batched matmul plus two broadcast additions. Because of
    floating-point rounding, entries for (near-)coincident points can come out
    slightly negative.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, m_dst, _ = dst.shape
    cross_term = -2 * torch.matmul(src, dst.permute(0, 2, 1))   # [B, N, M]
    src_norms = torch.sum(src ** 2, -1).view(batch, n_src, 1)    # [B, N, 1]
    dst_norms = torch.sum(dst ** 2, -1).view(batch, 1, m_dst)    # [B, 1, M]
    return cross_term + src_norms + dst_norms


def index_points(points, idx):
    """
    Gather points along dim 1, using a separate index set for each batch element.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]; trailing dims are also handled
             (e.g. [B, S, K], which gives [B, S, K, C])
    Return:
        new_points:, indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    # Original form: view_shape[1:] = [1] * (len(view_shape) - 1)
    # It is rebuilt below as a new list [B, 1, ..., 1] rather than by slice
    # assignment; presumably this is the change needed for ONNX export — confirm.
    new_view_shape = [view_shape[0]] + [1] * (len(view_shape) - 1)
    view_shape = new_view_shape

    # [1, S, ...]: tile the batch index across every non-batch position of idx.
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    # batch_indices has idx's shape and holds b at every position of batch b, so the
    # advanced index below pairs each idx entry with its own batch element.
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def farthest_point_sample(xyz, npoint: int):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device

Guess you like

Source: https://blog.csdn.net/weixin_44330367/article/details/132683983