首先,需要掌握libtorch的一些语法,可以参考下面的链接:
[https://www.cnblogs.com/yanghailin/p/12901586.html]
大概说下pytorch转libtorch流程:
1.先训练pytorch的模型,并测试
2.把pytorch模型转pt
3.写后处理
2.把pytorch模型转pt
这个可以单独写个脚本,也可以在跑测试脚本的时候在其中某个位置加上两句话就可以了。
单独写脚本例子如下:
import torch
from net import resnet
# Seg model
model = resnet()
state_dict = torch.load("e130_i391.pth")
model.load_state_dict(state_dict, strict=True)
for p in model.parameters():
p.requires_grad = False
model.eval()
model = model.cpu()
example = torch.rand(1, 3, 48, 640)
traced_script_module = torch.jit.trace(model, example)
print(traced_script_module)
traced_script_module.save("./01077cls.pt")
我一般喜欢直接拿测试脚本在某处加上两句话:
#########################
traced_script_module = torch.jit.trace(net, x)
traced_script_module.save("RefineDet.PyTorch-master/save_pt/refinedet320_0522_0_pytorch1_0.pt")
print("sys.exit(1)")
sys.exit(1)
#################
其中net就是初始化好的网络,x就是输入。如此即可生成pt
3.写后处理
后处理就是网络输出来模型的推理结果,要根据推理结果得到自己所需要的,可以仿照pytorch实现,一句一句的把python语句翻译成libtorch就可以。
在转refinedet的libtorch的时候并不是一帆风顺的,遇到了网上没有解答的错误!就是训练好的模型转pt报错!!!报错如下:
Finished loading model!
/opt/conda/conda-bld/pytorch_1573049310284/work/torch/csrc/autograd/python_function.cpp:622: UserWarning: Legacy autograd function with non-static forward method is deprecated and will be removed in 1.3. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)
/opt/conda/conda-bld/pytorch_1573049310284/work/torch/csrc/autograd/python_function.cpp:622: UserWarning: Legacy autograd function with non-static forward method is deprecated and will be removed in 1.3. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)
Traceback (most recent call last):
File "/data_1/Yang/project_new/2020/pytorch_refinedet/RefineDet.PyTorch-master/eval_refinedet_320.py", line 480, in <module>
thresh=args.confidence_threshold)
File "/data_1/Yang/project_new/2020/pytorch_refinedet/RefineDet.PyTorch-master/eval_refinedet_320.py", line 402, in test_net
output_names=['output'])
File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/onnx/__init__.py", line 26, in _export
result = utils._export(*args, **kwargs)
File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/onnx/utils.py", line 382, in _export
fixed_batch_size=fixed_batch_size)
File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/onnx/utils.py", line 249, in _model_to_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/onnx/utils.py", line 206, in _trace_and_get_graph_from_model
trace, torch_out, inputs_states = torch.jit.get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True)
File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/jit/__init__.py", line 275, in get_trace_graph
return LegacyTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/jit/__init__.py", line 352, in forward
out = self.inner(*trace_inputs)
File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 539, in __call__
result = self._slow_forward(*input, **kwargs)
File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 525, in _slow_forward
result = self.forward(*input, **kwargs)
File "/data_1/Yang/project_new/2020/pytorch_refinedet/RefineDet.PyTorch-master/models/refinedet.py", line 208, in forward
self.priors.type(type(x.data)) # default boxes
RuntimeError: Attempted to trace Detect_RefineDet, but tracing of legacy functions is not supported
哎,没有解答,大概就是说哪个不支持转pt。仔细研究了一下,在refinedet.py的一句
if self.phase == "test":
#print(loc, conf)
output = self.detect(
arm_loc.view(arm_loc.size(0), -1, 4), # arm loc preds
self.softmax(arm_conf.view(arm_conf.size(0), -1,
2)), # arm conf preds
odm_loc.view(odm_loc.size(0), -1, 4), # odm loc preds
self.softmax(odm_conf.view(odm_conf.size(0), -1,
self.num_classes)), # odm conf preds
self.priors.type(type(x.data)) # default boxes
)
这个self.detect调用如下的函数:
RefineDet.PyTorch-master/layers/functions/detection.py,
这个函数开头
import torch
from torch.autograd import Function
from ..box_utils import decode, nms
from data import voc as cfg
好像就是由于from torch.autograd import Function这个导致不支持转的。。。
弄了好久,然后我发现
output = self.detect(
arm_loc.view(arm_loc.size(0), -1, 4), # arm loc preds
self.softmax(arm_conf.view(arm_conf.size(0), -1,
2)), # arm conf preds
odm_loc.view(odm_loc.size(0), -1, 4), # odm loc preds
self.softmax(odm_conf.view(odm_conf.size(0), -1,
self.num_classes)), # odm conf preds
self.priors.type(type(x.data)) # default boxes
)
函数传进去的前4个值其实就是模型推理出来的结果,detect其实就是后处理,那么我直接输出这4个值就好了,后处理自己写!然后试了一下:在refinedet.py相应位置改成如下:
if self.phase == "test":
output = (arm_loc.view(arm_loc.size(0), -1, 4),
self.softmax(arm_conf.view(arm_conf.size(0), -1, 2)),
odm_loc.view(odm_loc.size(0), -1, 4),
self.softmax(odm_conf.view(odm_conf.size(0), -1, self.num_classes))
)
#print(loc, conf)
#output = self.detect(
# arm_loc.view(arm_loc.size(0), -1, 4), # arm loc preds
# self.softmax(arm_conf.view(arm_conf.size(0), -1,
# 2)), # arm conf preds
# odm_loc.view(odm_loc.size(0), -1, 4), # odm loc preds
# self.softmax(odm_conf.view(odm_conf.size(0), -1,
# self.num_classes)), # odm conf preds
# self.priors.type(type(x.data)) # default boxes
#)
果真可以!pt生成出来了!!!
然后折腾了好久,根据pt仿照这pytorch的后处理自己用libtorch写对应的后处理。查找资料,参考,弄了一个星期吧,终于给折腾出来了,并且建了我的第一个github,把我弄的上传github。链接如下:
[https://github.com/wuzuowuyou/libtorch_RefineDet_2020]
libtorch显存用的真少,320图片才860M。
github是cuda8,pytorch1.0的
我一开始是用libtorch1.3 pytorch1.3 cuda10.0实现的,下载链接如下
https://download.csdn.net/download/yang332233/12461623