# PyTorch模型转ONNX模型
import torch
import torch.nn as nn
from backbone import OXI_Net
# Build the network and load the trained weights. map_location='cpu' lets a
# checkpoint that was saved from GPU tensors load on a CPU-only machine; the
# model and the dummy input below both live on the CPU anyway.
model = OXI_Net()
model.load_state_dict(torch.load('./cnn_model_50.pth', map_location='cpu'))
# Evaluation mode, so layers such as dropout/batch-norm (if OXI_Net has any)
# are traced with their inference behaviour rather than training behaviour.
model.eval()

# Names assigned to the graph's input and output tensors in the ONNX file.
input_names = ['image']
output_names = ['label']

# Dummy input used only to trace the graph: shape (1, 3, 224, 224).
# No dynamic_axes are passed, so the exported model uses these fixed dims.
x = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, x, './cnn_model_50.onnx', input_names=input_names, output_names=output_names)
print("ONNX模型导出成功!")
# 使用 ONNX Runtime 运行 ONNX 模型
import onnxruntime
import numpy as np
import torchvision.transforms as transforms
import torchvision.transforms.functional as functional
from PIL import Image
# Path of the ONNX file produced by the export step.
onnx_model_path = "./cnn_model_50.onnx"
session = onnxruntime.InferenceSession(onnx_model_path)
# Read the graph's first input/output names from the session instead of
# hard-coding them.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Load the test image. convert('RGB') guarantees three channels (a BMP may be
# palettized, grayscale or RGBA), matching the 3-value Normalize below and the
# 3-channel input shape used at export time. The `with` block closes the
# underlying file handle once the pixels have been read by convert().
with Image.open('./idcard.bmp') as raw_image:
    image = raw_image.convert('RGB')
# Take the 648x648 region anchored at the top-left corner.
image = functional.crop(image, left=0, top=0, width=648, height=648)

# Deterministic preprocessing only. The original pipeline also contained
# RandomRotation(10); a random augmentation at inference time makes the
# prediction for the same image change from run to run, so it was removed.
# Resize(224) scales the shorter side to 224, i.e. 224x224 for the square crop.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
input_data = transform(image)
# Add a leading batch dimension of size 1 and convert to a NumPy array for
# ONNX Runtime.
input_data = input_data.unsqueeze(0).numpy()

# session.run returns a list with one array per requested output name.
output = session.run([output_name], {input_name: input_data})
predicted_result = output[0]
# argmax over the flattened output; with a batch of one this is the index of
# the highest-scoring entry.
predicted_class = np.argmax(predicted_result)
print("预测的结果为:", predicted_result)
print("预测的类别索引为:", predicted_class)