python关于onnx模型的一些基本操作

最近在对模型进行量化时候,模型格式转变为onnx模型了,因此需要对onnx进行加载、运行以及量化(权重/输入/输出)。故,对onnx模型的相关操作进行简单的学习,并写下了这边博客,若有错误请指出,谢谢。

一、onnx的配置环境

onnx的环境主要包含两个包onnx和onnxruntime,我们可以通过pip安装这两个依赖包。

pip install onnxruntime
pip install onnx

二、获取onnx模型的输出层

import onnx
# 加载模型
model = onnx.load('onnx_model.onnx')
# 检查模型格式是否完整及正确
onnx.checker.check_model(model)
# 获取输出层,包含层名称、维度信息
output = self.model.graph.output
print(output)

三、获取中节点输出数据

  onnx模型通常只能拿到最后输出节点的输出数据,若想拿到中间节点的输出数据,需要我们自己添加相应的输出节点信息;首先需要构建指定的节点(层名称、数据类型、维度信息);然后再通过insert的方式将节点插入到模型中。

import onnx
from onnx import helper
# 加载模型
model = onnx.load('onnx_model.onnx')
# 创建中间节点:层名称、数据类型、维度信息
prob_info =  helper.make_tensor_value_info('layer1',onnx.TensorProto.FLOAT, [1, 3, 320, 280])
# 将构建完成的中间节点插入到模型中
model.graph.output.insert(0, prob_info)
# 保存新的模型
onnx.save(model, 'onnx_model_new.onnx')

# 扩展:
# 删除指定的节点方法: item为需要删除的节点
# model.graph.output.remove(item)

四、onnx前向InferenceSession的使用

  关于onnx的前向推理,onnx使用了onnxruntime计算引擎。
  onnx runtime是一个用于onnx模型的推理引擎。微软联合Facebook等在2017年搞了个深度学习以及机器学习模型的格式标准–ONNX,顺路提供了一个专门用于ONNX模型推理的引擎(onnxruntime)。

import onnxruntime

# 创建一个InferenceSession的实例,并将模型的地址传递给该实例
sess = onnxruntime.InferenceSession('onnxmodel.onnx')
# 调用实例sess的润方法进行推理
outputs = sess.run(output_layers_name, {
    
    input_layers_name: x})

1. 创建实例,源码分析

class InferenceSession(Session):
    """
    This is the main class used to run a model.
    """
    def __init__(self, path_or_bytes, sess_options=None, providers=[]):
        """
        :param path_or_bytes: filename or serialized model in a byte string
        :param sess_options: session options
        :param providers: providers to use for session. If empty, will use
            all available providers.
        """
        self._path_or_bytes = path_or_bytes
        self._sess_options = sess_options
        self._load_model(providers)
        self._enable_fallback = True
        Session.__init__(self, self._sess)

    def _load_model(self, providers=[]):
        if isinstance(self._path_or_bytes, str):
            self._sess = C.InferenceSession(
                self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
                True)
        elif isinstance(self._path_or_bytes, bytes):
            self._sess = C.InferenceSession(
                self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
                False)
        # elif isinstance(self._path_or_bytes, tuple):
        # to remove, hidden trick
        #   self._sess.load_model_no_init(self._path_or_bytes[0], providers)
        else:
            raise TypeError("Unable to load from type '{0}'".format(type(self._path_or_bytes)))

        self._sess.load_model(providers)

        self._sess_options = self._sess.session_options
        self._inputs_meta = self._sess.inputs_meta
        self._outputs_meta = self._sess.outputs_meta
        self._overridable_initializers = self._sess.overridable_initializers
        self._model_meta = self._sess.model_meta
        self._providers = self._sess.get_providers()

        # Tensorrt can fall back to CUDA. All others fall back to CPU.
        if 'TensorrtExecutionProvider' in C.get_available_providers():
            self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            self._fallback_providers = ['CPUExecutionProvider']

  在_load_model函数,可以发现在load模型的时候是通过C.InferenceSession,并且将相关的操作也委托给该类。从导入语句from onnxruntime.capi import _pybind_state as C可知其实就是一个c++实现的Python接口,其源码在onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc中。

2. 模型推理run,源码分析

    def run(self, output_names, input_feed, run_options=None):
        """
        Compute the predictions.

        :param output_names: name of the outputs
        :param input_feed: dictionary ``{ input_name: input_value }``
        :param run_options: See :class:`onnxruntime.RunOptions`.

        ::

            sess.run([output_name], {input_name: x})
        """
        num_required_inputs = len(self._inputs_meta)
        num_inputs = len(input_feed)
        # the graph may have optional inputs used to override initializers. allow for that.
        if num_inputs < num_required_inputs:
            raise ValueError("Model requires {} inputs. Input Feed contains {}".format(num_required_inputs, num_inputs))
        if not output_names:
            output_names = [output.name for output in self._outputs_meta]
        try:
            return self._sess.run(output_names, input_feed, run_options)
        except C.EPFail as err:
            if self._enable_fallback:
                print("EP Error: {} using {}".format(str(err), self._providers))
                print("Falling back to {} and retrying.".format(self._fallback_providers))
                self.set_providers(self._fallback_providers)
                # Fallback only once.
                self.disable_fallback()
                return self._sess.run(output_names, input_feed, run_options)
            else:
                raise

  在run函数中,数据的推理是通过调用self._sess.run来进行前向推理的。同理该函数的具体实现实在c++的InferenceSession类中实现的。

五、遇到的一些问题

  1. 输入数据维度或类型不正确
    在这里插入图片描述
      从上图可以看出,该模型的输入数据的维度信息为[1, 3, 480, 640],输入数据类型为float32;所以在构建输入数据时,一定要按照该信息去构建,否则代码将会报错。

注: python调用c++代码是通过pybind11实现。

猜你喜欢

转载自blog.csdn.net/CFH1021/article/details/108732114