Convert a PyTorch model to a .pt file (callable from C++)

Reference: the official PyTorch website

To convert a PyTorch model into a serialized file that can be loaded and called from C++, follow these steps (using the tracing method):

1. Load the model parameters and switch the model to eval mode.

2. Define an example input with the same shape as the inputs used when training the model.

3. Run a forward pass with torch.jit.trace to record the executed operations, then save the result as a .pt file.

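import torch
from MobileNetV2 import mobilenet_v2  # the author's local model definition

# 1. Build the model, load the trained weights, and switch to eval mode
model = mobilenet_v2()
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
state_dict = torch.load("checkpoint/mobilenet-v2_0.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# 2. Example input with the same shape as the training inputs (N, C, H, W)
x = torch.rand(1, 3, 128, 128)

# 3. Run one forward pass under tracing, then save the TorchScript module
with torch.no_grad():
    ts = torch.jit.trace(model, x)
ts.save("mobilenet_v2.pt")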
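
To check that the exported file behaves like the original model, one option is to reload it with torch.jit.load and compare its output with the eager model on the same input. The sketch below is self-contained, so it swaps the author's mobilenet_v2 and checkpoint for a small hypothetical network with random weights; the names TinyNet, export_and_check and tiny.pt are illustrative assumptions, not part of the original post.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for mobilenet_v2: conv -> BN -> ReLU -> pool -> linear."""

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))


def export_and_check(model: nn.Module, example: torch.Tensor, path: str) -> float:
    """Trace `model` with `example`, save it to `path`, reload it, and return
    the max absolute difference between eager and reloaded outputs."""
    model.eval()  # fix BatchNorm/Dropout to inference behaviour before tracing
    with torch.no_grad():
        traced = torch.jit.trace(model, example)  # records the executed ops
        traced.save(path)                         # same format as mobilenet_v2.pt
        reloaded = torch.jit.load(path)           # load it back as TorchScript
        diff = (model(example) - reloaded(example)).abs().max().item()
    return diff


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.rand(1, 3, 128, 128)  # same input shape as in the post
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "tiny.pt")
        print("max abs diff:", export_and_check(TinyNet(), x, path))
```

A difference at the level of floating-point round-off suggests the trace reproduced the forward pass. Note that tracing records only the operations actually executed for the example input, so Python control flow that depends on the data is not captured. On the C++ side, LibTorch loads the same .pt file with torch::jit::load (declared in torch/script.h).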


Origin blog.csdn.net/Guo_Python/article/details/108825011