Pytorch モデルを Libtorch モデルに変換する

Pytroch モデル (.pth) を Libtorch モデル (.pt) に変換するコードは次のとおりです。

import torch
from backbone import Net

# 指定模型加载设备
device = torch.device('cpu')

# 创建模型实例
model = Net()

# 加载保存的模型
model.load_state_dict(torch.load('./result/pytorch_model_50.pth', map_location=device))

# 将模型设定为推理模式
model.eval()

# 随机生成一个与网络输入维度一致的tensor
example = torch.randn(1, 3, 224, 224)

# 转成.pt文件,并保存
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save('./result/libtorch_model_50.pt')

おすすめ

転載: blog.csdn.net/weixin_48158964/article/details/132346688