Reference: PyTorch official website
To convert a PyTorch model into a serialized file that can be loaded and called from C++, use the tracing method (torch.jit.trace). The steps are as follows:
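1. Build the model, load its trained parameters, and switch it to evaluation mode with model.eval(), so that layers such as Dropout and BatchNorm behave as they do at inference time.
2. Define an example input with the same shape as the inputs used during training.
3. Call torch.jit.trace, which runs a forward pass on the example input and records the operations executed, then save the traced module as a .pt file.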
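import torch
from MobileNetV2 import mobilenet_v2  # local module that defines the network

# 1. Build the model, load the trained weights, and switch to eval mode
model = mobilenet_v2()
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
state_dict = torch.load("checkpoint/mobilenet-v2_0.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# 2. Example input with the same shape as the training inputs (N, C, H, W)
x = torch.rand(1, 3, 128, 128)

# 3. Trace a forward pass and save the serialized TorchScript module
ts = torch.jit.trace(model, x)
ts.save("mobilenet_v2.pt")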
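A quick way to sanity-check the exported file is to load it back with torch.jit.load and confirm that it reproduces the eager model's output. The sketch below is a minimal example under stated assumptions: it uses a small hypothetical nn.Sequential stand-in instead of mobilenet_v2 so that it runs without the MobileNetV2 module or the checkpoint, and it writes to a temporary path rather than mobilenet_v2.pt.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for mobilenet_v2(); swap in the real model and checkpoint.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.eval()  # use BatchNorm running statistics, as at inference time

# Same input shape as in the steps above
x = torch.rand(1, 3, 128, 128)
traced = torch.jit.trace(model, x)

# Save and reload the TorchScript module (temporary path is an assumption)
path = os.path.join(tempfile.mkdtemp(), "model.pt")
traced.save(path)
loaded = torch.jit.load(path)

# The reloaded module should reproduce the eager model's output for the same input
with torch.no_grad():
    expected = model(x)
    actual = loaded(x)
print(torch.allclose(expected, actual, atol=1e-5))
```

With the real network, replace the stand-in with the loaded mobilenet_v2 and the path with "mobilenet_v2.pt". On the C++ side, LibTorch loads the same file with torch::jit::load.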