将训练好的Pytorch模型修改为可以在Android部署的样式

最近想要把训练好的Pytorch模型在Android端上部署,发现如果将直接训练好的模型直接运用到Android上会出现闪退的情况,所以需要将转模型进行转换。

查找了很多博客,都不能直接解决我的问题,所以经过一天的试错,终于把模型转换搞好了。

参考:
将Pytorch模型部署到Android端
pytorch官方教程

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
import os
from model import ResNet18

MODEL_PATH=''
model_pth = os.path.join(MODEL_PATH, 'test1_dict.pth')  #拼接原模型的路径

#搭建网络,可以自己的网络模型,也可以使用torchvision.model提供的模型
model=ResNet18.RestNet18_Net()

#加载参数
model.load_state_dict(torch.load(model_pth))

#模型设置为评测模式
model.eval()

example=torch.rand(1,3,384,384)

#模型转化
traced_script_module = torch.jit.trace(model, example)

#移动端优化
traced_script_module_optimized = optimize_for_mobile(traced_script_module)

#保存模型
traced_script_module_optimized._save_for_lite_interpreter("model4.pt")

最后在官方案例测试一下
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/m0_50127633/article/details/118857311