Referencia: sitio web oficial de pytorch
Según el ejemplo del sitio web oficial, este artículo convierte un modelo de clasificación mobilenetV2 en onnx
Principalmente dividido en los siguientes pasos:
1. Importe el modelo, cargue los parámetros de peso y configure el modelo en modo de evaluación.
2. Construya una entrada aleatoria El canal de entrada, la altura y el peso deben ser los mismos que durante el entrenamiento.
3. Exportar onnx.
import io
import torch
import torch.onnx
from MobileNetV2 import mobilenet_v2
torch_model = mobilenet_v2()
state_dict = torch.load("checkpoint/mobilenet-v2_0.pth", map_location='cuda:0')
torch_model.load_state_dict(state_dict)
torch_model.eval()
batch_size = 1
x = torch.randn(batch_size, 3, 128, 128, requires_grad=True)
torch.onnx.export(torch_model, x, "mobilenet_v2.onnx")