【代码记录】PyTorch 推理及其与 ONNX 推理的精度对比

1. PyTorch 推理

import functools
import os
import sys

import cv2
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch import nn

@functools.lru_cache(maxsize=1)
def _load_resnet18():
    """Build torchvision's pretrained ResNet-18 once, in eval mode, and cache it."""
    net = models.resnet18(pretrained=True)
    net.eval()
    return net


def inference(mat, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Classify one image with torchvision's pretrained ResNet-18 on CPU.

    Preprocessing: resize to 224x224, convert to a float tensor in [0, 1],
    then normalize per channel with ``mean``/``std``. The ONNX pipeline this
    is compared against must use exactly the same preprocessing.

    Args:
        mat: a PIL image (the Resize + ToTensor pipeline below needs one).
            Non-RGB images (e.g. grayscale, RGBA) are converted to RGB so the
            3-channel normalization applies.
        mean: per-channel normalization mean (ImageNet statistics by default).
        std: per-channel normalization std. The default is the ImageNet std
            the pretrained weights expect. The original code used 0.229 for
            all three channels; pass ``std=(0.229, 0.229, 0.229)`` only to
            reproduce that older preprocessing.

    Returns:
        numpy array of softmax probabilities, one per model output class.
    """
    # Only PIL images carry a ``mode``; convert anything that is not RGB.
    if getattr(mat, "mode", "RGB") != "RGB":
        mat = mat.convert("RGB")

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])

    device = torch.device("cpu")
    # Cached: building (and possibly downloading) the model per call is costly.
    net = _load_resnet18().to(device)
    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    batch = transform_test(mat).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = net(batch)[0]
    return softmax(logits.cpu().numpy())

def show_outputs(output, top_k=5):
    """Print the ``top_k`` highest scores of ``output`` with their indices.

    Each line is ``"<index>: <score>"``. A score that is not positive is
    printed as the placeholder ``"-1: 0.0"``. Each index appears at most
    once, and tied scores are listed in ascending index order. If ``output``
    has fewer than ``top_k`` entries, all of them are printed.

    Args:
        output: array-like of scores (e.g. softmax probabilities); input
            with more than one dimension is flattened.
        top_k: number of entries to print (default 5).
    """
    scores = np.asarray(output).ravel()
    # Descending order via the negated scores; 'stable' keeps ties in index
    # order. The previous sorted() + np.where lookup printed index *arrays*
    # (e.g. "[281]") and repeated tied indices.
    order = np.argsort(-scores, kind='stable')[:top_k]
    lines = ['\n-----TOP {}-----\n'.format(top_k)]
    for idx in order:
        value = scores[idx]
        if value > 0:
            lines.append('{}: {}\n'.format(idx, value))
        else:
            lines.append('-1: 0.0\n')
    print(''.join(lines))

def softmax(x, axis=0):
    """Return the softmax of ``x`` along ``axis`` (default 0, as before).

    The maximum along ``axis`` is subtracted before exponentiating. The
    result is mathematically unchanged, but large logits no longer overflow
    ``np.exp`` to ``inf``, which made the old version return ``nan``.
    """
    x = np.asarray(x)
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / np.sum(exps, axis=axis, keepdims=True)

if __name__ == '__main__':

    img_path = "./xxxx.jpg"  # placeholder: point this at a real image

    # Read with PIL. The context manager closes the underlying file once
    # inference has consumed the (lazily loaded) pixel data.
    with Image.open(img_path) as img:
        res = inference(img)

    print(res)

    show_outputs(res)

2. 精度对比（ONNX 与 PyTorch 输出）

from scipy.spatial.distance import cosine, euclidean


def compare_outputs(onnx_res, pt_res):
    """Print and return precision metrics between ONNX and PyTorch outputs.

    Both inputs are flattened to 1-D first, so e.g. a ``(1, N)`` batch output
    from one runtime can be compared with an ``(N,)`` vector from the other
    (``scipy.spatial.distance.cosine`` only accepts 1-D vectors).

    Usage: ``compare_outputs(onnx_res, pt_res)`` once both results exist.

    Args:
        onnx_res: output of the ONNX model (array-like).
        pt_res: output of the PyTorch model (array-like) with the same number
            of elements.

    Returns:
        Tuple ``(mean_error, mean_abs_error, max_abs_error, cos_sim)``.

    Raises:
        ValueError: if the inputs are empty or differ in size.
    """
    onnx_vec = np.asarray(onnx_res, dtype=np.float64).ravel()
    pt_vec = np.asarray(pt_res, dtype=np.float64).ravel()
    if onnx_vec.shape != pt_vec.shape:
        raise ValueError('size mismatch: {} vs {}'.format(onnx_vec.size, pt_vec.size))
    if onnx_vec.size == 0:
        raise ValueError('cannot compare empty outputs')

    diff = onnx_vec - pt_vec
    # Signed mean (kept for continuity): positive and negative deviations
    # cancel, so it can be ~0 even when the outputs clearly disagree.
    mean_error = diff.mean()
    # Absolute errors are the meaningful precision measures.
    mean_abs_error = np.abs(diff).mean()
    max_abs_error = np.abs(diff).max()
    cos_sim = 1 - cosine(onnx_vec, pt_vec)

    print('the mean error: {}'.format(mean_error))
    print('the mean absolute error: {}'.format(mean_abs_error))
    print('the max absolute error: {}'.format(max_abs_error))
    print('the cosine similarity: {}'.format(cos_sim))
    return mean_error, mean_abs_error, max_abs_error, cos_sim

猜你喜欢

转载自blog.csdn.net/weixin_45392674/article/details/127671658