Python环境下将ONNX模型转为fp16 半精度浮点方式

背景

在TX2上和NX上跑自己想要的模型还是有点慢,由于Jetpack4.6.2的TensorRT8.2对于有16G内存的NX支持存在问题运行不了(8G内存没有问题),可以运行的TensorRT7不支持我这边模型用到的einsum操作,所以我先想着改成fp16运行下看看

参考

https://blog.csdn.net/znsoft/article/details/114538684

流程

  1. 参考代码其实挺简单,但是python环境安装过程有点坎坷,建议新建一个虚拟环境来安装,好像有人把环境都直接装崩了
  2. 新建python3.7的虚拟环境,我是新建了基于python3.7的conda环境,注意哈,截至目前20220513来说这个winmltools无法在python3.8安装,build wheel会报错卡住,所以我最后安装的3.7的python,顺带吐槽下这破东西怎么要装这么多个版本的scipy还是啥的,就离谱
  3. 直接命令行安装:
pip install winmltools
  1. 安装好之后大概就可以按照下面代码把模型修改了:
from winmltools.utils import convert_float_to_float16
from winmltools.utils import load_model, save_model
onnx_model = load_model('model.onnx')
new_onnx_model = convert_float_to_float16(onnx_model)
save_model(new_onnx_model, 'model_fp16.onnx')

报错

我这边这个模型碰到了小问题,报错:

(op_type:AveragePool, name:AveragePool_141): Inferred shape and existing shape differ in dimension 2: (8) vs (7)
Traceback (most recent call last):
  File "G:/jupyter/fp16_convert/fp16_convert.py", line 4, in <module>
    new_onnx_model = convert_float_to_float16(onnx_model)
  File "D:\ProgramData\Anaconda\envs\fp16_convert\lib\site-packages\onnxconverter_common\float16.py", line 139, in convert_float_to_float16
    model = func_infer_shape(model)
  File "D:\ProgramData\Anaconda\envs\fp16_convert\lib\site-packages\onnx\shape_inference.py", line 36, in infer_shapes
    inferred_model_str = C.infer_shapes(model_str)
RuntimeError: Inferred shape and existing shape differ in dimension 2: (8) vs (7)

Process finished with exit code 1

由于我是验证过的,可能是其他模型转onnx遇到了点小bug,把它infer那一段跳过就好了。根据报错内容跳转到shape_inference.py中,作如下修改:


def infer_shapes(model):  # type: (ModelProto) -> ModelProto
    if not isinstance(model, ModelProto):
        raise ValueError('Shape inference only accepts ModelProto, '
                         'incorrect type: {}'.format(type(model)))
    model_str = model.SerializeToString()
    return onnx.load_from_string(model_str)
    inferred_model_str = C.infer_shapes(model_str)
    return onnx.load_from_string(inferred_model_str)

重新运行代码,生成成功,放到NX开发板上跑,比float的快了大概1.5倍的样子。

猜你喜欢

转载自blog.csdn.net/weixin_42492254/article/details/124757094
今日推荐