[Заметки по коду] Инференс в PyTorch и сравнение точности с инференсом ONNX

1. Инференс в PyTorch

import cv2
import sys
import numpy as np
import torch, os
from torch import nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

def inference(mat, net=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Run an image classifier on one image and return class probabilities.

    The preprocessing here (resize to 224x224, ToTensor, Normalize) must be
    kept identical to the preprocessing used for the ONNX model, otherwise
    the PyTorch/ONNX accuracy comparison is meaningless.

    Args:
        mat: Input image. PIL images are converted to RGB first; anything
            else is passed to the transforms unchanged.
        net: Optional ``torch.nn.Module`` to run. When ``None`` (default),
            torchvision's pretrained ResNet-18 is constructed on every call.
            The module is moved to CPU and switched to eval mode in place.
        mean: Per-channel normalization mean.
        std: Per-channel normalization std. The default is the standard
            ImageNet std; previously 0.229 was used for all three channels,
            which does not match what the pretrained weights expect.

    Returns:
        NumPy array of softmax probabilities over the outputs for the first
        (only) image in the batch.
    """
    # 1- or 4-channel images (grayscale, RGBA, palette) would be rejected
    # by Normalize, which is given 3 mean/std values.
    if isinstance(mat, Image.Image):
        mat = mat.convert("RGB")

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])
    # Add the batch dimension in front of the transformed image tensor.
    batch = transform_test(mat).unsqueeze(0)

    device = torch.device('cpu')
    if net is None:
        net = models.resnet18(pretrained=True)
    net = net.to(device)
    net.eval()
    with torch.no_grad():
        logits = net(batch.to(device))[0]
        # Softmax over the class scores of the single image.
        return torch.softmax(logits, dim=0).cpu().numpy()

def show_outputs(output, topk=5):
    """Print (and return) the ``topk`` highest scores with their indices.

    Args:
        output: Array-like of float scores/probabilities, e.g. the result of
            ``inference``. Multi-dimensional input is flattened, so indices
            then refer to the flattened array.
        topk: Number of entries to report (default 5). If fewer scores are
            available, all of them are reported.

    Returns:
        The formatted report string that was printed.

    Entries are ordered by descending score. Equal scores are listed in
    ascending index order, each index exactly once. A non-positive score is
    reported as the placeholder line ``-1: 0.0``.
    """
    scores = np.asarray(output).ravel()
    # Stable sort of the negated scores: descending order, ties by index.
    # Slicing also clamps topk to the number of scores.
    order = np.argsort(-scores, kind='stable')[:topk]

    lines = ['\n-----TOP {}-----\n'.format(topk)]
    for idx in order:
        value = scores[idx]
        if value > 0:
            lines.append('{}: {}\n'.format(idx, value))
        else:
            lines.append('-1: 0.0\n')

    top_str = ''.join(lines)
    print(top_str)
    return top_str

def softmax(x, axis=0):
    """Return the softmax of ``x`` along ``axis`` (default: first axis).

    The maximum is subtracted before exponentiating. This does not change
    the result mathematically, but it keeps large logits from overflowing
    ``exp`` to ``inf``, which would otherwise produce ``nan``. For 1-D
    input the default axis is the only axis.
    """
    x = np.asarray(x)
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / np.sum(exps, axis=axis, keepdims=True)

if __name__ == '__main__':
    # Image to classify; may be overridden by the first command-line argument.
    img_path = sys.argv[1] if len(sys.argv) > 1 else "./xxxx.jpg"

    # Read with PIL. The context manager closes the file handle, and
    # convert("RGB") guarantees the 3 channels the normalization expects.
    with Image.open(img_path) as img:
        res = inference(img.convert("RGB"))

    print(res)

    show_outputs(res)

2. Сравнение точности

import numpy as np
from scipy.spatial.distance import cosine, euclidean

#-------------------------------------------------------------------#
#          Compute error metrics & cosine similarity                #
#-------------------------------------------------------------------#
# Compares the ONNX and PyTorch outputs for the same input. onnx_res and
# pt_res must be defined before this snippet runs (e.g. pt_res = inference(img)).
# NOTE(review): assumes both outputs have the same number of elements.
onnx_vec = np.asarray(onnx_res, dtype=np.float64).ravel()
pt_vec = np.asarray(pt_res, dtype=np.float64).ravel()  # scipy distances need 1-D
diff = np.abs(onnx_vec - pt_vec)
# Absolute value: with a signed mean, positive and negative deviations cancel.
mean_error = diff.mean()
max_error = diff.max()
cos_sim = 1 - cosine(onnx_vec, pt_vec)
print('the mean absolute error: {}'.format(mean_error))
print('the max absolute error: {}'.format(max_error))
print('the euclidean distance: {}'.format(euclidean(onnx_vec, pt_vec)))
print('the cosine similarity: {}'.format(cos_sim))

Guess you like

Origin blog.csdn.net/weixin_45392674/article/details/127671658