Pytorch模型转ONNX模型并使用ONNXRuntime运行

Pytorch模型转ONNX模型

import torch
import torch.nn as nn
from backbone import OXI_Net


model = OXI_Net()  # 实例化模型对象
model.load_state_dict(torch.load('./cnn_model_50.pth'))  # 加载模型
model.eval()  # 推理模式

input_names = ['image']  # 输入名称
output_names = ['label']  # 输出名称

x = torch.randn(1, 3, 224, 224)  # 模型输入的形状和类型

torch.onnx.export(model, x, './cnn_model_50.onnx', input_names=input_names, output_names=output_names)  # 导出onnx模型

print("ONNX模型导出成功!")

使用ONNXRuntime运行ONNX模型

import onnxruntime
import numpy as np
import torchvision.transforms as transforms
import torchvision.transforms.functional as functional
from PIL import Image


# 加载ONNX模型
onnx_model_path = "./cnn_model_50.onnx"
session = onnxruntime.InferenceSession(onnx_model_path)

# 设置输入名称和输出名称
input_name = session.get_inputs()[0].name  # 获取输入的名称
output_name = session.get_outputs()[0].name  # 获取输出的名称

# 读取和预处理图像
image = Image.open('./idcard.bmp')
image = functional.crop(image, left=0, top=0, width=648, height=648)
transform = transforms.Compose([transforms.Resize(224),
                                transforms.RandomRotation(10),
                                transforms.ToTensor(),
                                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

input_data = transform(image)
input_data = input_data.unsqueeze(0)  # 添加批处理维度
input_data = np.array(input_data)  # 转成numpy数组

# 使用ONNXRuntime进行推断
output = session.run([output_name], {
    
    input_name: input_data})

# 处理推断结果
predicted_result = output[0]
predicted_class = np.argmax(predicted_result)

print("预测的结果为:", predicted_result)
print("预测的类别索引为:", predicted_class)

猜你喜欢

转载自blog.csdn.net/weixin_48158964/article/details/132468897