Definición de implementación del modelo
La implementación del modelo de aprendizaje profundo se refiere al proceso de ejecutar un modelo entrenado en un entorno específico.
La canalización para la implementación del modelo es la siguiente:
- Use cualquier marco de aprendizaje profundo para definir la estructura de la red y entrenar el modelo
- La estructura de red y los parámetros del modelo entrenado se convertirán en una representación intermedia que solo describe la estructura de la red (p. ej., onnx, torchscript, etc.), y se realizarán algunas optimizaciones específicas del modelo en la representación intermedia (p. ej., onnxsimplify, leer el modelo onnx, convertir algunos valores que deben calcularse dinámicamente en valores estáticos, lo que simplifica el modelo; recorte de nodos onnx, etc.)
- Use un marco de programación de alto rendimiento orientado al hardware (como CUDA, OpenCL, etc.) para escribir un motor de inferencia que pueda ejecutar operadores de aprendizaje profundo de manera eficiente, convertir la representación intermedia en un formato de archivo específico y ejecutar el modelo de manera eficiente en la plataforma de hardware correspondiente
Ejemplo de implementación del modelo
1. Cree un entorno virtual para la implementación de modelos
conda create -n modeldeploy python=3.8 -y
conda activate modeldeploy
2. Instale las bibliotecas de terceros necesarias para la implementación del modelo
conda install pytorch torchvision cpuonly -c pytorch
pip install onnxruntime onnx opencv-python
3. Defina el modelo pytorch de superresolución SRCNN y pruebe el modelo
"""
该代码来自: https://zhuanlan.zhihu.com/p/477743341
"""
import cv2
import numpy as np
import torch
import torch.nn as nn
class SuperResolutionNet(nn.Module):
def __init__(self, upscale_factor):
super().__init__()
self.upscale_factor = upscale_factor
self.img_upsampler = nn.Upsample(
scale_factor=self.upscale_factor,
mode='bicubic',
align_corners=False)
self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4)
self.conv2 = nn.Conv2d(64,32,kernel_size=1,padding=0)
self.conv3 = nn.Conv2d(32,3,kernel_size=5,padding=2)
self.relu = nn.ReLU()
def forward(self, x):
x = self.img_upsampler(x)
out = self.relu(self.conv1(x))
out = self.relu(self.conv2(out))
out = self.conv3(out)
return out
def init_torch_model():
torch_model = SuperResolutionNet(upscale_factor=3)
state_dict = torch.load('srcnn.pth')['state_dict']
# Adapt the checkpoint
for old_key in list(state_dict.keys()):
new_key = '.'.join(old_key.split('.')[1:])
state_dict[new_key] = state_dict.pop(old_key)
torch_model.load_state_dict(state_dict)
torch_model.eval()
return torch_model
model = init_torch_model()
input_img = cv2.imread('face.png').astype(np.float32)
input_img = cv2.resize(input_img, [256, 256])
# HWC to NCHW
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)
# Inference
torch_output = model(torch.from_numpy(input_img)).detach().numpy()
# NCHW to HWC
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)
# Show image
cv2.imwrite("face_torch.png", torch_output)
4. Convierta el modelo pth en modelo onnx
# convert pth to onnx
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
# opset_version为onnx算子集的版本, 版本越高, 支持的算子越多
torch.onnx.export(model, x, "srcnn.onnx", opset_version=11,
input_names=['input'], output_names=['output'])
5. Compruebe si el archivo del modelo onnx convertido es correcto
# verify
import onnx
onnx_model = onnx.load("srcnn.onnx")
try:
onnx.checker.check_model(onnx_model)
except Exception:
print("Model incorrect")
else:
print("Model correct")
"""
================ Diagnostic Run torch.onnx.export version 2.0.1 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
Model correct
"""
6. Use el motor de inferencia onnxruntime para inferir el modelo onnx
# inference
import onnxruntime
ort_session = onnxruntime.InferenceSession("srcnn.onnx")
ort_inputs = {'input': input_img}
ort_output = ort_session.run(['output'], ort_inputs)[0]
ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort.png", ort_output)
print(torch.equal(torch.from_numpy(ort_output), torch.from_numpy(torch_output)))
【Artículo de referencia】