Jetson Nano 【9】 pytorch 直转tensorRT的模型序列化

前提:

思路:

-  1.python类的序列化(显然不太靠谱),我试了一下,果然不太靠谱
-  2.参考tensorRT官方文档
-  3.参考torch2trt官方git
参考tensorRT官方文档(证明在此份代码不可行,但是是可以序列话的)
  • 思路1太蠢了,直接掠过,直接看思路2,一开始感觉还是比较靠谱的
  • 官方文档传送
  • 来到3.4 用python序列化一个模型
  • 在这里插入图片描述
  • 我们不难找出相关代码
# 序列化
serialized_engine = engine.serialize()
# 序列化并保存
with open(“sample.engine”, “wb”) as f:
		f.write(engine.serialize())

# 反序列化
with trt.Runtime(TRT_LOGGER) as runtime:
	engine = runtime.deserialize_cuda_engine(serialized_engine)
# 从文件中反序列化
with open(“sample.engine”, “rb”) as f, trt.Runtime(TRT_LOGGER) as runtime:
		engine = runtime.deserialize_cuda_engine(f.read())
  • 那么问题来了,这个engine到底是个啥东西?这要从tensorRT build网络模型开始说起~不过这个我还是不太懂,有机会整理一下,就看看它是怎么来的
  • 关于这个内容,我们正好可以借用一下torch2trt的源码
def torch2trt(module, 
              inputs, 
              input_names=None, 
              output_names=None, 
              log_level=trt.Logger.ERROR, 
              max_batch_size=1,
              fp16_mode=False, 
              max_workspace_size=0, 
              strict_type_constraints=False, 
              keep_network=True, 
              int8_mode=False, 
              int8_calib_dataset=None,
              int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM):

    inputs_in = inputs
    
    # copy inputs to avoid modifications to source data
    inputs = [tensor.clone()[0:1] for tensor in inputs]  # only run single entry
    
    logger = trt.Logger(log_level)
    builder = trt.Builder(logger)
    network = builder.create_network()
    
    with ConversionContext(network) as ctx:

        if isinstance(inputs, list):
            inputs = tuple(inputs)
        if not isinstance(inputs, tuple):
            inputs = (inputs, )
        ctx.add_inputs(inputs, input_names)

        outputs = module(*inputs)

        if not isinstance(outputs, tuple) and not isinstance(outputs, list):
            outputs = (outputs, )
        ctx.mark_outputs(outputs, output_names)

    builder.max_workspace_size = max_workspace_size
    builder.fp16_mode = fp16_mode
    builder.max_batch_size = max_batch_size
    builder.strict_type_constraints = strict_type_constraints
    
    if int8_mode:
        
        # default to use input tensors for calibration
        if int8_calib_dataset is None:
            int8_calib_dataset = TensorBatchDataset(inputs_in)
        
        builder.int8_mode = True
        
        # @TODO(jwelsh):  Should we set batch_size=max_batch_size?  Need to investigate memory consumption
        builder.int8_calibrator = DatasetCalibrator(inputs, int8_calib_dataset, batch_size=1, algorithm=int8_calib_algorithm)

    engine = builder.build_cuda_engine(network)
    
    module_trt = TRTModule(engine, ctx.input_names, ctx.output_names)
        
    if keep_network:
        module_trt.network = network
            
    return module_trt
  • engine = builder.build_cuda_engine(network),简单来说,builder 负责构造网络,而engine也是由builder给build出来的
  • 于是可以用官方给出的代码序列化,但是由于torch2trt返回的是TRTModule这个类型,所以,我们只能强行转换里面的engine,在build的时候再吧这个engine传回去,但我的测试发现虽然,这个方法可行(表示序列化成功,模型构建成功),但后续预测会报错,具体原因未知。
  • 下面是代码,trance
serialized_engine = model_trt.engine.serialize()
with open("speed.engine", "wb") as f:
    f.write(serialized_engine)
  • 强行转换会报这个错误

# 强行转换会报这个错误
Traceback (most recent call last):
  File "/home/nano/Desktop/YOLOv3-Torch2TRT/mydetect.py", line 190, in <module>
    detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres, method=2)
  File "/home/nano/Desktop/YOLOv3-Torch2TRT/utils/utils.py", line 254, in non_max_suppression
    image_pred = image_pred[(-score).argsort()]
IndexError: too many indices for tensor of dimension 2
  • 序列化反序列化结果对比(两个engine除了地址那是一毛一样,但是最后就是没成功,果真有些玄学~),但至少说明序列化应该是没啥毛病
    的
参考torch2trt官方git(这份代码适合,是TRTModule类型)
  • 怪我当时没看仔细,ReadME写着了
    在这里插入图片描述
# 序列化
torch.save(model_trt.state_dict(), 'alexnet_trt.pth')

# 读取
from torch2trt import TRTModule
model_trt = TRTModule()
model_trt.load_state_dict(torch.load('alexnet_trt.pth'))
  • 序列化没测,反序列化测了一下

    • YOLO v3 tiny 大约24秒加载完毕
    • YOLO v3 大约34秒加载完毕
    • YOLO v3 spp 大约37秒加载完毕
  • 可以正常使用

发布了118 篇原创文章 · 获赞 170 · 访问量 19万+

猜你喜欢

转载自blog.csdn.net/symuamua/article/details/104623718