pytorch refinedet libtorch实现

首先,需要掌握libtorch的一些语法,可以参考下面的链接:
[https://www.cnblogs.com/yanghailin/p/12901586.html]

大概说下pytorch转libtorch流程:

1.先训练pytorch的模型,并测试

2.把pytorch模型转pt

3.写后处理

2.把pytorch模型转pt

这个可以单独写个脚本,也可以在跑测试脚本的时候在其中某个位置加上两句话就可以了。
单独写脚本例子如下:

import torch
from net import resnet
# Seg model
model = resnet()
state_dict = torch.load("e130_i391.pth")
model.load_state_dict(state_dict, strict=True)
for p in model.parameters():
    p.requires_grad = False
model.eval()
model = model.cpu()

example = torch.rand(1, 3, 48, 640)
traced_script_module = torch.jit.trace(model, example)
print(traced_script_module)
traced_script_module.save("./01077cls.pt")

我一般喜欢直接拿测试脚本在某处加上两句话:

#########################
        traced_script_module = torch.jit.trace(net, x)
        traced_script_module.save("RefineDet.PyTorch-master/save_pt/refinedet320_0522_0_pytorch1_0.pt")
        print("sys.exit(1)")
        sys.exit(1)
#################

其中net就是初始化好的网络,x就是输入。如此即可生成pt

3.写后处理

后处理就是网络输出来模型的推理结果,要根据推理结果得到自己所需要的,可以仿照pytorch实现,一句一句的把python语句翻译成libtorch就可以。

在转refinedet的libtorch的时候并不是一帆风顺的,遇到了网上没有解答的错误!就是训练好的模型转pt报错!!!报错如下:

Finished loading model!
/opt/conda/conda-bld/pytorch_1573049310284/work/torch/csrc/autograd/python_function.cpp:622: UserWarning: Legacy autograd function with non-static forward method is deprecated and will be removed in 1.3. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)
/opt/conda/conda-bld/pytorch_1573049310284/work/torch/csrc/autograd/python_function.cpp:622: UserWarning: Legacy autograd function with non-static forward method is deprecated and will be removed in 1.3. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)
Traceback (most recent call last):
  File "/data_1/Yang/project_new/2020/pytorch_refinedet/RefineDet.PyTorch-master/eval_refinedet_320.py", line 480, in <module>
    thresh=args.confidence_threshold)
  File "/data_1/Yang/project_new/2020/pytorch_refinedet/RefineDet.PyTorch-master/eval_refinedet_320.py", line 402, in test_net
    output_names=['output'])
  File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/onnx/__init__.py", line 26, in _export
    result = utils._export(*args, **kwargs)
  File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/onnx/utils.py", line 382, in _export
    fixed_batch_size=fixed_batch_size)
  File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/onnx/utils.py", line 249, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/onnx/utils.py", line 206, in _trace_and_get_graph_from_model
    trace, torch_out, inputs_states = torch.jit.get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True)
  File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/jit/__init__.py", line 275, in get_trace_graph
    return LegacyTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/jit/__init__.py", line 352, in forward
    out = self.inner(*trace_inputs)
  File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 539, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/data_1/Yang/software_install/Anaconda1105/envs/DB_cuda10_2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 525, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/data_1/Yang/project_new/2020/pytorch_refinedet/RefineDet.PyTorch-master/models/refinedet.py", line 208, in forward
    self.priors.type(type(x.data))                  # default boxes
RuntimeError: Attempted to trace Detect_RefineDet, but tracing of legacy functions is not supported

哎,没有解答,大概就是说哪个不支持转pt。仔细研究了一下,在refinedet.py的一句

     if self.phase == "test":
            #print(loc, conf)
            output = self.detect(
                arm_loc.view(arm_loc.size(0), -1, 4),           # arm loc preds
                self.softmax(arm_conf.view(arm_conf.size(0), -1,
                             2)),                               # arm conf preds
                odm_loc.view(odm_loc.size(0), -1, 4),           # odm loc preds
                self.softmax(odm_conf.view(odm_conf.size(0), -1,
                             self.num_classes)),                # odm conf preds
                self.priors.type(type(x.data))                  # default boxes
            )

这个self.detect调用如下的函数:
RefineDet.PyTorch-master/layers/functions/detection.py,
这个函数开头

import torch
from torch.autograd import Function
from ..box_utils import decode, nms
from data import voc as cfg

好像就是由于from torch.autograd import Function这个导致不支持转的。。。
弄了好久,然后我发现

output = self.detect(
                arm_loc.view(arm_loc.size(0), -1, 4),           # arm loc preds
                self.softmax(arm_conf.view(arm_conf.size(0), -1,
                             2)),                               # arm conf preds
                odm_loc.view(odm_loc.size(0), -1, 4),           # odm loc preds
                self.softmax(odm_conf.view(odm_conf.size(0), -1,
                             self.num_classes)),                # odm conf preds
                self.priors.type(type(x.data))                  # default boxes
            )

函数传进去的前4个值其实就是模型推理出来的结果,detect其实就是后处理,那么我直接输出这4个值就好了,后处理自己写!然后试了一下:在refinedet.py相应位置改成如下:

        if self.phase == "test":
            output = (arm_loc.view(arm_loc.size(0), -1, 4),
                      self.softmax(arm_conf.view(arm_conf.size(0), -1, 2)),
                      odm_loc.view(odm_loc.size(0), -1, 4),
                      self.softmax(odm_conf.view(odm_conf.size(0), -1, self.num_classes))
                      )
            #print(loc, conf)
            #output = self.detect(
            #    arm_loc.view(arm_loc.size(0), -1, 4),           # arm loc preds
            #    self.softmax(arm_conf.view(arm_conf.size(0), -1,
            #                 2)),                               # arm conf preds
            #    odm_loc.view(odm_loc.size(0), -1, 4),           # odm loc preds
            #    self.softmax(odm_conf.view(odm_conf.size(0), -1,
            #                 self.num_classes)),                # odm conf preds
            #    self.priors.type(type(x.data))                  # default boxes
            #)

果真可以!pt生成出来了!!!
然后折腾了好久,根据pt仿照这pytorch的后处理自己用libtorch写对应的后处理。查找资料,参考,弄了一个星期吧,终于给折腾出来了,并且建了我的第一个github,把我弄的上传github。链接如下:
[https://github.com/wuzuowuyou/libtorch_RefineDet_2020]
libtorch显存用的真少,320图片才860M。
github是cuda8,pytorch1.0的

我一开始是用libtorch1.3 pytorch1.3 cuda10.0实现的,下载链接如下
https://download.csdn.net/download/yang332233/12461623

猜你喜欢

转载自www.cnblogs.com/yanghailin/p/12965695.html