【代码】PyTorch 转 ONNX 及 ONNX 推理

工作中经常用到，每次去翻之前的代码比较麻烦，干脆记录一下！

1. PyTorch 转 ONNX

import torch
from nets.vit import vit

# Build the network with the same arguments used for training, then load the
# trained weights (a state_dict) onto the CPU.
model = vit(input_shape=[250, 200], num_classes=3, pretrained=False)
weight = torch.load("./best_epoch_weights.pth", map_location='cpu')
model.load_state_dict(weight)

batch_size = 1
input_shape = (3, 250, 200)  # (C, H, W) of the dummy input; change to your own input shape
model.eval()

x = torch.randn(batch_size, *input_shape)
export_onnx_file = "test.onnx"  # output ONNX file name
torch.onnx.export(
    model,
    x,
    export_onnx_file,
    opset_version=11,
    do_constant_folding=True,
    verbose=True,  # print the exported graph
    input_names=["input"],
    output_names=["output"],
    # Only the batch axis is dynamic. The model is built for a fixed
    # input_shape=[250, 200], so H and W must stay static. Reusing the name
    # 'batch' for axes 2 and 3 would declare H == W == batch size. The output
    # presumably has shape (batch, num_classes) with no axes 2/3; verify for
    # your model. If you do need dynamic spatial axes, give each axis its own
    # name (e.g. 2: 'height', 3: 'width').
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)

简化版本（直接加载用 torch.save(model) 保存的完整模型再导出）

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def model_converter(model_path='test.pth', onnx_path='kkktest.onnx',
                    input_shape=(1, 3, 224, 224)):
    """Export a fully pickled PyTorch model to an ONNX file.

    Args:
        model_path: checkpoint written by ``torch.save(model, ...)``, i.e. the
            whole ``nn.Module`` rather than a ``state_dict``.
        onnx_path: destination path of the exported ONNX file.
        input_shape: shape of the dummy input used to trace the model; change
            it to match your model's input.
    """
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # whole nn.Module, so opt out explicitly (the argument exists since torch
    # 1.13). Unpickling can execute arbitrary code: only load trusted files.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model = torch.load(model_path, map_location=device, weights_only=False).to(device)
    model.eval()

    dummy_input = torch.randn(*input_shape, device=device)
    torch.onnx.export(model, dummy_input, onnx_path,
                      export_params=True,
                      verbose=True,  # print the exported graph
                      input_names=['input'],
                      output_names=['output'])


# Run the conversion
model_converter()

2. ONNX 推理代码

其中预处理的 PIL 方法和 cv2 方法都可以用，选择一种即可（注意：两种方法的 mean/std 都必须与训练时保持一致）。

import os
import cv2
import argparse
import numpy as np
from PIL import Image
import onnxruntime
import numbers
import torchvision.transforms as transforms

def preprocess(mat, size=(224, 224),
               mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Turn an image into a normalized (1, 3, H, W) float32 batch for the model.

    Two input types are accepted:

    * a PIL image (e.g. from ``Image.open``), processed with torchvision
      transforms;
    * a numpy array in BGR channel order (e.g. from ``cv2.imread``),
      processed with cv2 and numpy.

    Args:
        mat: PIL image, or a BGR array of shape (H, W, 3).
        size: target (height, width) after resizing.
        mean: per-channel RGB mean used during training.
        std: per-channel RGB std used during training.

    Returns:
        np.ndarray of dtype float32 and shape (1, 3, size[0], size[1]).
    """
    if isinstance(mat, np.ndarray):
        # cv2 path: BGR -> RGB, resize (cv2 takes (width, height)),
        # scale to [0, 1], then normalize per channel.
        img = cv2.cvtColor(mat, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (size[1], size[0]))
        image = img.astype(np.float32) / 255.0
        image = (image - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)
        image = image.transpose((2, 0, 1))  # HWC -> CHW
        # Add the batch axis; make it contiguous for ONNX Runtime.
        return np.ascontiguousarray(image[np.newaxis], dtype=np.float32)

    # PIL path. Force 3-channel RGB so grayscale/RGBA/palette images match the
    # 3-element mean/std (Normalize would fail or mis-normalize otherwise).
    if mat.mode != 'RGB':
        mat = mat.convert('RGB')
    transform_test = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),  # HWC uint8 -> CHW float32 in [0, 1]
        transforms.Normalize(mean=mean, std=std),
    ])
    img = transform_test(mat)
    # Add the batch axis: (3, H, W) -> (1, 3, H, W).
    return np.expand_dims(img.numpy().astype(np.float32), 0)

def show_outputs(output, k=5, title='resnet50v2'):
    """Print (and return) the top-k entries of a score vector.

    Args:
        output: class scores (e.g. softmax probabilities); flattened first.
        k: number of entries to report; capped at the number of scores.
        title: header line printed above the table.

    Returns:
        The printed report string.
    """
    scores = np.asarray(output).ravel()
    lines = [title, '-----TOP {}-----'.format(k)]
    # Stable sort on the negated scores: highest first, ties in index order.
    # Slicing the argsort also copes with fewer than k classes, where indexing
    # a sorted list up to k would raise IndexError.
    for idx in np.argsort(-scores, kind='stable')[:k]:
        value = scores[idx]
        # Non-positive scores are reported as a '-1: 0.0' placeholder row.
        if value > 0:
            lines.append('{}: {}'.format(idx, value))
        else:
            lines.append('-1: 0.0')
    report = '\n'.join(lines) + '\n'
    print(report)
    return report

def softmax(x, axis=0):
    """Return the softmax of ``x`` along ``axis``.

    ``axis`` defaults to 0, which matches the previous builtin-``sum``
    behaviour (that summed over the first axis). Pass ``axis=-1`` for
    (batch, classes) logits.
    """
    x = np.asarray(x)
    # Subtracting the max leaves the result unchanged mathematically, but it
    # stops np.exp from overflowing to inf (inf / inf -> nan) on large logits.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

if __name__ == '__main__':
    # Paths can be overridden on the command line; defaults are the originals.
    parser = argparse.ArgumentParser(description='Run an ONNX classifier on one image.')
    parser.add_argument('--img', default='/home/kk/mnt/图像.jpg', help='input image path')
    parser.add_argument('--model', default='/home/kk/mnt/kkktest.onnx', help='ONNX model path')
    args = parser.parse_args()

    # PIL reading path. The `with` block closes the image file handle once
    # preprocessing has consumed the pixels.
    with Image.open(args.img) as img:
        imgdata = preprocess(img)

    # cv2 reading path (alternative). cv2.imread returns None instead of
    # raising on a bad path, so check it before preprocessing:
    # img = cv2.imread(args.img)
    # if img is None:
    #     raise FileNotFoundError(args.img)

    # ONNX Runtime >= 1.9 GPU builds require `providers` to be set explicitly;
    # pin the CPU provider (swap in 'CUDAExecutionProvider' for GPU inference).
    sess = onnxruntime.InferenceSession(args.model, providers=['CPUExecutionProvider'])
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    pred_onnx = sess.run([output_name], {input_name: imgdata})
    # First requested output, first (only) sample of the batch.
    res = np.array(pred_onnx[0][0])

    res = softmax(res)

    print("outputs:")
    print(res)
    show_outputs(res)

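备注：
PIL 和 cv2 读图的区别有点忘了，后续再总结。简单来说：PIL 的 `Image.open` 返回 PIL Image 对象（彩色图为 RGB 通道顺序），而 `cv2.imread` 返回形状为 (H, W, C)、BGR 通道顺序的 uint8 numpy 数组，且路径错误时返回 None 而不报错。因此 cv2 读入的图需要先转成 RGB，再做与训练时一致的归一化。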

猜你喜欢

转载自blog.csdn.net/weixin_45392674/article/details/127671191