Building and Running ONNX Models

        ONNX (Open Neural Network Exchange) is an open model format originally developed by Microsoft and Facebook and now maintained together with a community of partners. Models from many deep learning training frameworks (such as TensorFlow, PyTorch, scikit-learn, and MXNet) can be exported or converted to the standard ONNX format. With ONNX as a unified interface, an embedded platform only needs to parse ONNX models instead of supporting every training framework. This article shows how to build a single-operator ONNX model or a whole graph in code or JSON files. It also shows how to run inference with ONNX Runtime to obtain the outputs of the operator or model.

1. ONNX file format

        ONNX files are serialized with Protobuf. Readers familiar with Protobuf will know that a Protobuf message format is defined in a *.proto file. The ONNX format is defined in onnx/onnx.proto3 in the onnx/onnx repository on GitHub.

        The key data structures defined in onnx.proto3 are:

  • ModelProto: The top-level model definition, containing version information, the producer, and the GraphProto.
  • GraphProto: Contains repeated lists of NodeProto, initializers (TensorProto), and ValueInfoProto, which together form the computation graph. These elements are stored as lists. Nodes are connected by name: an output name of one node appears as an input name of another.
  • NodeProto: The ONNX computation graph is a directed acyclic graph (DAG), and each NodeProto is one node in it. It defines the operator type, the node's inputs and outputs, and its attributes.
  • ValueInfoProto: Defines the type (element type and shape) of input, output, and intermediate values.
  • TensorProto: Serialized tensor data such as weights, including data type, shape, etc.
  • AttributeProto: A named attribute. It can store basic types (int, float, string, and lists of these) or ONNX-defined structures (TENSOR, GRAPH, etc.).

2. Python APIs

2.1 Build ONNX model

        ONNX describes a network structure as a DAG: a network (Graph) is composed of nodes (Node) and edges (Tensor). The helper module provided by ONNX has many functions for building ONNX models, such as make_node, make_graph, and make_tensor. The following example builds a network containing a single Conv operator:

import numpy as np
import onnx
from onnx import helper
from onnx import TensorProto

# Conv weights: 2 output channels, 2 input channels, 3x3 kernel (float32 to match TensorProto.FLOAT)
weight = np.random.randn(2, 2, 3, 3).astype(np.float32)

X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2, 4, 4])
W = helper.make_tensor('W', TensorProto.FLOAT, [2, 2, 3, 3], weight.flatten().tolist())
B = helper.make_tensor('B', TensorProto.FLOAT, [2], [1.0, 2.0])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 2, 2, 2])

node_def = helper.make_node(
    'Conv',           # op_type
    ['X', 'W', 'B'],  # inputs
    ['Y'],            # outputs
    # attributes: at stride 2, padding 1 is needed for the 4x4 input to give the declared 2x2 output
    strides=[2, 2],
    pads=[1, 1, 1, 1],
)

graph_def = helper.make_graph(
    [node_def],
    'test_conv_model',
    [X],  # graph inputs
    [Y],  # graph outputs
    initializer=[W, B],
)

model_def = helper.make_model(graph_def, producer_name='onnx-example')
onnx.checker.check_model(model_def)
onnx.save(model_def, "./Conv.onnx")

        Visualizing the resulting Conv model with Netron gives the figure below:

        This example shows how to build an ONNX model with the helper functions make_tensor_value_info, make_tensor, make_node, make_graph, and make_model.

        Compared with PyTorch or other frameworks, these APIs are rather cumbersome. In practice, large networks are rarely built directly with them. Instead, the model is built in another framework and converted to ONNX.

2.2 Shape Inference

        Models converted from PyTorch, TensorFlow, or other frameworks often carry no shape information for their intermediate tensors, as shown in the following figure:

 

        It is often useful to see the shapes of specific tensors directly. The shape_inference module infers the shapes of the intermediate tensors in the graph and stores them in graph.value_info, which makes the visualized model much easier to read:

import onnx
from onnx import shape_inference
onnx_model = onnx.load("./test_data/mobilenetv2-1.0.onnx")
onnx_model = shape_inference.infer_shapes(onnx_model)
onnx.save(onnx_model, "./test_data/mobilenetv2-1.0_shaped.onnx")

        Visualizing the model after shape inference gives the following:

2.3 ONNX Optimizer

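        The ONNX optimizer provides a set of graph optimization passes. The most commonly used are fuse_bn_into_conv and fuse_pad_into_conv. Since ONNX 1.9, the optimizer is no longer part of the onnx package. It is maintained separately and installed with python3 -m pip install onnxoptimizer.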

        List the available optimization passes:

import onnxoptimizer  # separate package since ONNX 1.9: python3 -m pip install onnxoptimizer

all_passes = onnxoptimizer.get_available_passes()
print("Available optimization passes:")
for p in all_passes:
    print(p)
print()

        Apply a pass to transform the model:

import onnx
import onnxoptimizer

onnx_model = onnx.load("./test_data/mobilenetv2-1.0.onnx")
passes = ['fuse_bn_into_conv']
# Apply the optimization on the original model
optimized_model = onnxoptimizer.optimize(onnx_model, passes)

        After fuse_bn_into_conv is applied to MobileNetV2, the BatchNormalization parameters are folded into the weights and biases of the preceding Conv, as shown in the following figure:
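        The folding itself is simple arithmetic. For each output channel, inference-mode BatchNormalization computes gamma * (y - mean) / sqrt(var + eps) + beta on the Conv output y = W*x + b. It can therefore be absorbed into the Conv in two steps:

  • Scale that channel's filter by s = gamma / sqrt(var + eps).
  • Replace the bias with (b - mean) * s + beta.

        The NumPy sketch below checks this identity numerically. It illustrates the math and is not onnxoptimizer's actual implementation. It assumes a 1x1 kernel to keep the reference convolution short; the per-channel scaling is the same for any kernel size.

```python
import numpy as np

rng = np.random.default_rng(0)
C_in, C_out, eps = 4, 3, 1e-5
x = rng.standard_normal((2, C_in, 5, 5))
W = rng.standard_normal((C_out, C_in, 1, 1))  # assumption: 1x1 kernel for brevity
b = rng.standard_normal(C_out)
gamma, beta = rng.standard_normal(C_out), rng.standard_normal(C_out)
mean, var = rng.standard_normal(C_out), rng.uniform(0.5, 2.0, C_out)

def conv1x1(x, W, b):
    # out[n, m, h, w] = sum_c W[m, c] * x[n, c, h, w] + b[m]
    return np.einsum("mc,nchw->nmhw", W[:, :, 0, 0], x) + b[None, :, None, None]

def batchnorm(y):
    # Inference-mode BatchNormalization with fixed running statistics.
    s = (1, -1, 1, 1)
    return gamma.reshape(s) * (y - mean.reshape(s)) / np.sqrt(var.reshape(s) + eps) + beta.reshape(s)

# Fold BN into the conv: scale each output channel's filter and shift its bias.
scale = gamma / np.sqrt(var + eps)
W_fused = W * scale[:, None, None, None]
b_fused = (b - mean) * scale + beta

ref = batchnorm(conv1x1(x, W, b))
fused = conv1x1(x, W_fused, b_fused)
print("max abs diff:", np.abs(ref - fused).max())
```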

 

3. Running the ONNX model with ONNX Runtime

        ONNX itself is only a specification that defines operators, the model structure, and so on. It does not perform any computation. ONNX Runtime is an engine that runs ONNX models, somewhat like a JVM for ONNX. It handles model parsing, graph optimization, and execution on backend hardware.

        Install ONNX Runtime:

python3 -m pip install onnxruntime

        Inference:

import numpy as np
import onnxruntime as ort
import cv2

def preprocess(img_data):
    # ImageNet per-channel (RGB) mean and standard deviation
    mean_vec = np.array([0.485, 0.456, 0.406])
    stddev_vec = np.array([0.229, 0.224, 0.225])
    norm_img_data = np.zeros(img_data.shape).astype('float32')
    for i in range(img_data.shape[0]):
        # for each pixel in each channel, divide the value by 255 to get value between [0, 1] and then normalize
        norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]
    return norm_img_data

img = cv2.imread("test_data/dog.jpeg")             # HWC, BGR, uint8
img = cv2.resize(img, (224,224), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
input_data = np.transpose(img, (2, 0, 1))           # HWC -> CHW
input_data = preprocess(input_data)
input_data = input_data.reshape([1, 3, 224, 224])   # add the batch dimension (NCHW)

sess = ort.InferenceSession("test_data/mobilenetv2-1.0.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
result = sess.run(None, {input_name: input_data})[0]  # first (and only) output
result = np.reshape(result, [1, -1])
index = np.argmax(result)
print("max index:", index)
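        The script above prints only the arg-max class index. If the model outputs raw scores (logits), a softmax turns them into probabilities, and a top-k list is usually more informative. Whether the output is logits is an assumption worth checking for the MobileNetV2 file you use. The sketch below uses only NumPy. softmax and top_k are hypothetical helper names, and fake_logits merely stands in for result.

```python
import numpy as np

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

def top_k(logits, k=5):
    # Flatten a (1, C) or (C,) score array and return the k most probable classes.
    probs = softmax(np.asarray(logits, dtype=np.float64).reshape(-1))
    idx = np.argsort(probs)[::-1][:k]
    return [(int(i), float(probs[i])) for i in idx]

# Hypothetical logits standing in for the 1x1000 `result` above.
fake_logits = np.zeros(1000, dtype=np.float32)
fake_logits[[207, 208, 250]] = [5.0, 3.0, 1.0]
for class_idx, p in top_k(fake_logits, k=3):
    print(class_idx, round(p, 4))
```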


Source: blog.csdn.net/soralaro/article/details/127927783