Convert Pytorch model to Libtorch model

Convert the Pytroch model (.pth) to the Libtorch model (.pt), the code is as follows:

import torch
from backbone import Net

# 指定模型加载设备
device = torch.device('cpu')

# 创建模型实例
model = Net()

# 加载保存的模型
model.load_state_dict(torch.load('./result/pytorch_model_50.pth', map_location=device))

# 将模型设定为推理模式
model.eval()

# 随机生成一个与网络输入维度一致的tensor
example = torch.randn(1, 3, 224, 224)

# 转成.pt文件,并保存
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save('./result/libtorch_model_50.pt')

Guess you like

Origin blog.csdn.net/weixin_48158964/article/details/132346688