[Code] PyTorch to ONNX export and ONNX inference

I use this often at work, and digging through old code every time is tedious, so I am recording it here.

1. PyTorch to ONNX

import torch
from nets.vit import vit

# Build the ViT classifier and restore the trained weights on the CPU.
model = vit(input_shape=[250, 200], num_classes=3, pretrained=False)
weight = torch.load("./best_epoch_weights.pth", map_location="cpu")
model.load_state_dict(weight)
# Put the model in evaluation mode before it is traced for export.
model.eval()

batch_size = 1
input_shape = (3, 250, 200)  # input data shape; change it to your own input shape

# Random dummy input that the exporter runs through the model.
x = torch.randn(batch_size, *input_shape)
export_onnx_file = "test.onnx"  # name of the ONNX file to write

# Only axis 0 is declared dynamic. The previous mapping gave axes 0, 2 and 3
# the same symbolic name ('batch'), which ties batch size, height and width to
# one dimension, and it also declared axes 2 and 3 on the output.
# NOTE(review): the output is presumably (batch, num_classes) and the model is
# built for a fixed 250x200 input -- confirm. If height/width really must be
# dynamic, give them distinct names on the input only,
# e.g. {0: "batch", 2: "height", 3: "width"}.
torch.onnx.export(
    model,
    x,
    export_onnx_file,
    opset_version=11,
    do_constant_folding=True,
    verbose=True,  # print the exported network
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch"},
        "output": {0: "batch"},
    },
)

Simplified version

import torch

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
def model_converter(model_path='test.pth', onnx_path='kkktest.onnx',
                    input_shape=(1, 3, 224, 224)):
    """Export a fully pickled PyTorch model to an ONNX file.

    Args:
        model_path: Checkpoint that stores the complete model object
            (not just a ``state_dict``).
        onnx_path: Path of the ONNX file to write.
        input_shape: Shape of the random dummy input fed to the exporter;
            change it to match your own model.
    """
    # map_location lets a checkpoint that was saved on a GPU be loaded on a
    # CPU-only machine instead of failing while unpickling CUDA tensors.
    # NOTE: torch.load unpickles arbitrary Python objects -- only load
    # checkpoints from a trusted source. On PyTorch >= 2.6 the default is
    # weights_only=True, so a complete model object additionally requires
    # passing weights_only=False.
    model = torch.load(model_path, map_location=device).to(device)
    model.eval()

    dummy_input = torch.randn(*input_shape, device=device)
    torch.onnx.export(model, dummy_input, onnx_path,
                      export_params=True,
                      verbose=True,
                      input_names=['input'],
                      output_names=['output'])
# Run the conversion
model_converter()

2. ONNX inference code

Either the PIL method or the cv2 method can be used for reading and preprocessing the image; just choose one.

import os
import cv2
import argparse
import numpy as np
from PIL import Image
import onnxruntime
import numbers
import torchvision.transforms as transforms

# Per-channel (RGB order) normalization statistics. They must match the
# mean/std that were used when the model was trained.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)
# Spatial size fed to the network.
_INPUT_SIZE = (224, 224)


def _preprocess_pil(image):
    """Resize, scale and normalize a PIL image with torchvision transforms."""
    pipeline = transforms.Compose([
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ])
    # Force three channels so RGBA / grayscale / palette images do not break
    # the 3-channel normalization.
    tensor = pipeline(image.convert('RGB'))
    # Add the leading batch axis: (3, H, W) -> (1, 3, H, W).
    return np.expand_dims(tensor.numpy().astype(np.float32), 0)


def _preprocess_cv2(image):
    """Resize, scale and normalize a BGR array as returned by ``cv2.imread``."""
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, _INPUT_SIZE)                      # (224, 224, 3)
    scaled = rgb.astype(np.float32) / 255.0
    mean = np.array(_NORM_MEAN, dtype=np.float32)
    std = np.array(_NORM_STD, dtype=np.float32)
    normalized = (scaled - mean) / std
    chw = normalized.transpose((2, 0, 1))                   # (3, 224, 224)
    # Add the batch axis and make the result a contiguous float32 array.
    return np.ascontiguousarray(chw[np.newaxis, :, :, :], dtype=np.float32)


def preprocess(mat):
    """Turn an input image into a normalized float32 batch of one image.

    Args:
        mat: Either a PIL image (e.g. from ``Image.open``) or a NumPy array
            in BGR channel order (e.g. from ``cv2.imread``).

    Returns:
        A float32 NumPy array of shape (1, 3, 224, 224).
    """
    if isinstance(mat, np.ndarray):
        return _preprocess_cv2(mat)
    return _preprocess_pil(mat)

# Show the top-5 results
def show_outputs(output):
    """Print the highest-scoring classes, best first.

    At most five rows are printed; if ``output`` holds fewer than five
    scores, only that many rows appear. A row reads ``<index>: <score>``
    for a positive score and ``-1: 0.0`` otherwise.

    Args:
        output: Sequence or array of per-class scores.
    """
    scores = np.asarray(output).ravel()
    top_k = min(5, scores.size)
    # sorted() is stable even with reverse=True, so tied scores keep their
    # original index order and every index is listed exactly once.
    ranked = sorted(range(scores.size), key=lambda i: scores[i], reverse=True)

    rows = ['resnet50v2\n-----TOP 5-----\n']
    for class_idx in ranked[:top_k]:
        value = scores[class_idx]
        if value > 0:
            rows.append('{}: {}\n'.format(class_idx, value))
        else:
            rows.append('-1: 0.0\n')
    print(''.join(rows))

def softmax(x):
    """Return the softmax of ``x`` taken along its first axis.

    The maximum is subtracted before exponentiating so that large logits do
    not overflow to ``inf`` and produce NaN.

    Args:
        x: Array-like of scores (logits).

    Returns:
        NumPy array of the same shape whose entries along axis 0 sum to 1.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalize; return an empty float array.
        return np.exp(x)
    shifted = np.exp(x - np.max(x, axis=0))
    return shifted / np.sum(shifted, axis=0)

if __name__ == '__main__':
    image_file = '/home/kk/mnt/图像.jpg'
    model_file = '/home/kk/mnt/kkktest.onnx'

    # Read the image with PIL ...
    picture = Image.open(image_file)

    # ... or read it with cv2 instead:
    # picture = cv2.imread(image_file)

    # Preprocess the image into the array fed to the model.
    batch = preprocess(picture)

    # Load the ONNX model and look up the names of its first input/output.
    session = onnxruntime.InferenceSession(model_file)
    feed_name = session.get_inputs()[0].name
    fetch_name = session.get_outputs()[0].name

    # Run inference and take the first item of the first requested output.
    results = session.run([fetch_name], {feed_name: batch})
    probabilities = softmax(np.array(results[0][0]))

    print("outputs:")
    print(probabilities)
    show_outputs(probabilities)

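Note:
I have partly forgotten the differences between reading images with PIL and with cv2, and I will summarize them later. (One difference visible in the code above: an image read with cv2 is in BGR channel order and has to be converted to RGB before normalization.)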

Guess you like

Origin blog.csdn.net/weixin_45392674/article/details/127671191