Use onnx y onnxruntime para completar la implementación del modelo

Definición de implementación del modelo

La implementación del modelo de aprendizaje profundo se refiere al proceso de ejecutar un modelo entrenado en un entorno específico.

La canalización para la implementación del modelo es la siguiente:

  1. Use cualquier marco de aprendizaje profundo para definir la estructura de la red y entrenar el modelo
  2. La estructura de red y los parámetros del modelo entrenado se convertirán en una representación intermedia que solo describe la estructura de la red (p. ej., onnx, torchscript, etc.), y se realizarán algunas optimizaciones específicas del modelo en la representación intermedia (p. ej., onnxsimplify, leer el modelo onnx, convertir algunos valores que deben calcularse dinámicamente en valores estáticos, lo que simplifica el modelo; recorte de nodos onnx, etc.)
  3. Use un marco de programación de alto rendimiento orientado al hardware (como CUDA, OpenCL, etc.) para escribir un motor de inferencia que pueda ejecutar operadores de aprendizaje profundo de manera eficiente, convertir la representación intermedia en un formato de archivo específico y ejecutar el modelo de manera eficiente en la plataforma de hardware correspondiente

Ejemplo de implementación del modelo

1. Cree un entorno virtual para la implementación de modelos

conda create -n modeldeploy python=3.8 -y
conda activate modeldeploy

2. Instale las bibliotecas de terceros necesarias para la implementación del modelo

conda install pytorch torchvision cpuonly -c pytorch
pip install onnxruntime onnx opencv-python

3. Defina el modelo pytorch de superresolución SRCNN y pruebe el modelo

"""
该代码来自: https://zhuanlan.zhihu.com/p/477743341
"""

import cv2
import numpy as np
import torch
import torch.nn as nn


class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor, 
            mode='bicubic', 
            align_corners=False)
        self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4)
        self.conv2 = nn.Conv2d(64,32,kernel_size=1,padding=0)
        self.conv3 = nn.Conv2d(32,3,kernel_size=5,padding=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.img_upsampler(x)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out


def init_torch_model():
    torch_model = SuperResolutionNet(upscale_factor=3)

    state_dict = torch.load('srcnn.pth')['state_dict']

    # Adapt the checkpoint 
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model

model = init_torch_model()
input_img = cv2.imread('face.png').astype(np.float32)
input_img = cv2.resize(input_img, [256, 256])

# HWC to NCHW 
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)

# Inference 
torch_output = model(torch.from_numpy(input_img)).detach().numpy()

# NCHW to HWC 
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)
 
# Show image 
cv2.imwrite("face_torch.png", torch_output)

4. Convierta el modelo pth en modelo onnx

# convert pth to onnx
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    # opset_version为onnx算子集的版本, 版本越高, 支持的算子越多
    torch.onnx.export(model, x, "srcnn.onnx", opset_version=11, 
                    input_names=['input'], output_names=['output'])

5. Compruebe si el archivo del modelo onnx convertido es correcto

# verify
import onnx
onnx_model = onnx.load("srcnn.onnx") 
try: 
    onnx.checker.check_model(onnx_model) 
except Exception: 
    print("Model incorrect") 
else: 
    print("Model correct")


"""
================ Diagnostic Run torch.onnx.export version 2.0.1 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Model correct
"""

6. Use el motor de inferencia onnxruntime para inferir el modelo onnx

# inference
import onnxruntime
ort_session = onnxruntime.InferenceSession("srcnn.onnx")
ort_inputs = {'input': input_img}
ort_output = ort_session.run(['output'], ort_inputs)[0]
ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort.png", ort_output)
print(torch.equal(torch.from_numpy(ort_output), torch.from_numpy(torch_output)))

Artículo de referencia

Tutorial de introducción a la implementación de modelos (1): Introducción a la implementación de modelos

Supongo que te gusta

Origin blog.csdn.net/qq_38964360/article/details/131780719
Recomendado
Clasificación