Algumas operações básicas do modelo python onnx

Recentemente, quando o modelo foi quantificado, o formato do modelo foi alterado para o modelo onnx, então onnx precisa ser carregado, executado e quantificado (peso / entrada / saída). Portanto, eu simplesmente aprendi as operações relacionadas ao modelo onnx e escrevi um blog aqui.Se houver algum erro, indique-o, obrigado.

Um, ambiente de configuração onnx

O ambiente onnx contém principalmente dois pacotes, onnx e onnxruntime, e podemos instalar esses dois pacotes dependentes por meio do pip.

pip install onnxruntime
pip install onnx

Em segundo lugar, obtenha a camada de saída do modelo onnx

import onnx
# 加载模型
model = onnx.load('onnx_model.onnx')
# 检查模型格式是否完整及正确
onnx.checker.check_model(model)
# 获取输出层,包含层名称、维度信息
output = self.model.graph.output
print(output)

Três, obtenha os dados de saída do nó intermediário

  O modelo onnx geralmente só pode obter os dados de saída do último nó de saída. Se quisermos obter os dados de saída do nó intermediário, precisamos adicionar nós mesmos as informações do nó de saída correspondente; primeiro, precisamos construir o nó especificado (nome da camada, tipo de dados, informações de dimensão) ; Em seguida, insira o nó no modelo por inserir.

import onnx
from onnx import helper
# 加载模型
model = onnx.load('onnx_model.onnx')
# 创建中间节点:层名称、数据类型、维度信息
prob_info =  helper.make_tensor_value_info('layer1',onnx.TensorProto.FLOAT, [1, 3, 320, 280])
# 将构建完成的中间节点插入到模型中
model.graph.output.insert(0, prob_info)
# 保存新的模型
onnx.save(model, 'onnx_model_new.onnx')

# 扩展:
# 删除指定的节点方法: item为需要删除的节点
# model.graph.output.remove(item)

Quarto, o uso de onnx forward InferenceSession

  Com relação ao raciocínio direto de onnx, onnx usa o mecanismo de cálculo onnxruntime.
  onnx runtime é um mecanismo de inferência para modelos onnx. Em 2017, a Microsoft, junto com o Facebook e outros, desenvolveu um padrão de formato para modelos de aprendizado profundo e aprendizado de máquina - ONNX, e forneceu um mecanismo (onnxruntime) dedicado à inferência de modelos ONNX.

import onnxruntime

# 创建一个InferenceSession的实例,并将模型的地址传递给该实例
sess = onnxruntime.InferenceSession('onnxmodel.onnx')
# 调用实例sess的润方法进行推理
outputs = sess.run(output_layers_name, {
    
    input_layers_name: x})

1. Crie exemplos, análise de código-fonte

class InferenceSession(Session):
    """
    This is the main class used to run a model.
    """
    def __init__(self, path_or_bytes, sess_options=None, providers=[]):
        """
        :param path_or_bytes: filename or serialized model in a byte string
        :param sess_options: session options
        :param providers: providers to use for session. If empty, will use
            all available providers.
        """
        self._path_or_bytes = path_or_bytes
        self._sess_options = sess_options
        self._load_model(providers)
        self._enable_fallback = True
        Session.__init__(self, self._sess)

    def _load_model(self, providers=[]):
        if isinstance(self._path_or_bytes, str):
            self._sess = C.InferenceSession(
                self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
                True)
        elif isinstance(self._path_or_bytes, bytes):
            self._sess = C.InferenceSession(
                self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
                False)
        # elif isinstance(self._path_or_bytes, tuple):
        # to remove, hidden trick
        #   self._sess.load_model_no_init(self._path_or_bytes[0], providers)
        else:
            raise TypeError("Unable to load from type '{0}'".format(type(self._path_or_bytes)))

        self._sess.load_model(providers)

        self._sess_options = self._sess.session_options
        self._inputs_meta = self._sess.inputs_meta
        self._outputs_meta = self._sess.outputs_meta
        self._overridable_initializers = self._sess.overridable_initializers
        self._model_meta = self._sess.model_meta
        self._providers = self._sess.get_providers()

        # Tensorrt can fall back to CUDA. All others fall back to CPU.
        if 'TensorrtExecutionProvider' in C.get_available_providers():
            self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            self._fallback_providers = ['CPUExecutionProvider']

  Na função _load_model, pode-se descobrir que C.InferenceSession é usado ao carregar o modelo, e as operações relacionadas também são delegadas a esta classe. A partir da instrução import de onnxruntime.capi import _pybind_state as C, pode-se ver que na verdade é uma interface Python implementada por C ++ e seu código-fonte está em onnxruntime \ onnxruntime \ python \ onnxruntime_pybind_state.cc.

2. Modelo de raciocínio executado, análise de código-fonte

    def run(self, output_names, input_feed, run_options=None):
        """
        Compute the predictions.

        :param output_names: name of the outputs
        :param input_feed: dictionary ``{ input_name: input_value }``
        :param run_options: See :class:`onnxruntime.RunOptions`.

        ::

            sess.run([output_name], {input_name: x})
        """
        num_required_inputs = len(self._inputs_meta)
        num_inputs = len(input_feed)
        # the graph may have optional inputs used to override initializers. allow for that.
        if num_inputs < num_required_inputs:
            raise ValueError("Model requires {} inputs. Input Feed contains {}".format(num_required_inputs, num_inputs))
        if not output_names:
            output_names = [output.name for output in self._outputs_meta]
        try:
            return self._sess.run(output_names, input_feed, run_options)
        except C.EPFail as err:
            if self._enable_fallback:
                print("EP Error: {} using {}".format(str(err), self._providers))
                print("Falling back to {} and retrying.".format(self._fallback_providers))
                self.set_providers(self._fallback_providers)
                # Fallback only once.
                self.disable_fallback()
                return self._sess.run(output_names, input_feed, run_options)
            else:
                raise

  Na função run, a inferência de dados é a inferência direta chamando self._sess.run. Da mesma forma, a implementação específica desta função é implementada na classe InferenceSession do C ++.

Cinco, alguns problemas encontrados

  1. A dimensão ou tipo de dados de entrada está incorreto
    Insira a descrição da imagem aqui
      . Como pode ser visto na figura acima, as informações de dimensão dos dados de entrada do modelo são [1, 3, 480, 640] e o tipo de dados de entrada é float32; portanto, ao construir os dados de entrada, você deve seguir isto Informações a serem construídas, caso contrário, o código relatará um erro.

Nota: Python chama o código C ++ por meio de pybind11.

Acho que você gosta

Origin blog.csdn.net/CFH1021/article/details/108732114
Recomendado
Clasificación