Reference: PyTorch official website
Following the example on the PyTorch website, this article converts a MobileNetV2 classification model to ONNX.
The conversion has three steps:
1. Import the model, load the weight parameters, and set the model to eval mode.
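2. Construct a random input. The exporter traces the model with this input and records its shape, so the channel count, height, and width should match what the model saw during training (here 3 channels, 128 x 128 pixels).
3. Export the model to ONNX with torch.onnx.export.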
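Step 1's eval() call matters because some layers behave differently in training and inference. The sketch below uses a small hypothetical network (not the article's MobileNetV2) containing BatchNorm and Dropout, and shows that repeated forward passes differ in train mode but are identical in eval mode, which is the deterministic behavior the exported graph should capture.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy network with the two layer types whose behavior
# depends on train/eval mode: BatchNorm and Dropout.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Dropout(p=0.5),
)

x = torch.randn(2, 3, 16, 16)

# Train mode: Dropout zeroes a random subset of activations on every call,
# and BatchNorm normalizes with the current batch's statistics.
net.train()
with torch.no_grad():
    a, b = net(x), net(x)
print("train mode, identical outputs:", torch.equal(a, b))

# Eval mode: Dropout is a no-op and BatchNorm uses its stored running
# statistics, so the same input always gives the same output.
net.eval()
with torch.no_grad():
    c, d = net(x), net(x)
print("eval mode, identical outputs:", torch.equal(c, d))
```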
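import torch
import torch.onnx
from MobileNetV2 import mobilenet_v2  # local module that defines the network

# Step 1: build the model, load the trained weights, switch to eval mode.
# map_location="cpu" lets the checkpoint load on machines without a GPU.
torch_model = mobilenet_v2()
state_dict = torch.load("checkpoint/mobilenet-v2_0.pth", map_location="cpu")
torch_model.load_state_dict(state_dict)
torch_model.eval()

# Step 2: random input in the training layout (N, C, H, W) = (1, 3, 128, 128).
batch_size = 1
x = torch.randn(batch_size, 3, 128, 128)

# Step 3: trace the model with x and write the ONNX graph to disk.
torch.onnx.export(torch_model, x, "mobilenet_v2.onnx")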