modelo de pytorch a onnx

Referencia: sitio web oficial de pytorch

Según el ejemplo del sitio web oficial, este artículo convierte un modelo de clasificación mobilenetV2 en onnx

Principalmente dividido en los siguientes pasos:

1. Importe el modelo, cargue los parámetros de peso y configure el modelo en modo de evaluación.

2. Construya una entrada aleatoria El canal de entrada, la altura y el peso deben ser los mismos que durante el entrenamiento.

3. Exportar onnx.

import io
import torch
import torch.onnx
from MobileNetV2 import mobilenet_v2


torch_model = mobilenet_v2()

state_dict = torch.load("checkpoint/mobilenet-v2_0.pth", map_location='cuda:0')

torch_model.load_state_dict(state_dict)
torch_model.eval()

batch_size = 1
x = torch.randn(batch_size, 3, 128, 128, requires_grad=True)

torch.onnx.export(torch_model, x, "mobilenet_v2.onnx")

 

Supongo que te gusta

Origin blog.csdn.net/Guo_Python/article/details/108821599
Recomendado
Clasificación