Pytorch模型转成Libtorch模型

将Pytroch模型(.pth)转成Libtorch模型(.pt),代码如下所示:

import torch
from backbone import Net

# 指定模型加载设备
device = torch.device('cpu')

# 创建模型实例
model = Net()

# 加载保存的模型
model.load_state_dict(torch.load('./result/pytorch_model_50.pth', map_location=device))

# 将模型设定为推理模式
model.eval()

# 随机生成一个与网络输入维度一致的tensor
example = torch.randn(1, 3, 224, 224)

# 转成.pt文件,并保存
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save('./result/libtorch_model_50.pt')

猜你喜欢

转载自blog.csdn.net/weixin_48158964/article/details/132346688