PyTorch model to ONNX

Reference: PyTorch official website

Following the example on the official website, this article converts a MobileNetV2 classification model to ONNX.

The conversion consists of the following steps:

1. Build the model, load the trained weights, and set the model to eval mode.

2. Construct a random dummy input. Its channel count, height, and width should match those used during training.

3. Export the model to ONNX.

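import torch
import torch.onnx
from MobileNetV2 import mobilenet_v2  # the author's local MobileNetV2 definition


# Step 1: build the model, load the trained weights, switch to eval mode
torch_model = mobilenet_v2()

# map_location='cpu' lets the checkpoint load even on a machine without a GPU
state_dict = torch.load("checkpoint/mobilenet-v2_0.pth", map_location="cpu")

torch_model.load_state_dict(state_dict)
torch_model.eval()  # BatchNorm/Dropout must be in inference mode before export

# Step 2: dummy input with the same channels/height/width as in training
batch_size = 1
x = torch.randn(batch_size, 3, 128, 128)

# Step 3: export; the model is traced with x, so this input shape is fixed in the ONNX graph
torch.onnx.export(torch_model, x, "mobilenet_v2.onnx")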


Origin blog.csdn.net/Guo_Python/article/details/108821599