TensorRT Inference for Handwritten Digit Classification (3)

Series Article Directory

(1) Build and train a model with PyTorch
(2) Convert the pth file to ONNX format
(3) Convert the ONNX file to a serialized engine file and perform inference



Foreword

  In the previous section, we successfully converted the pth file to ONNX format and verified the exported ONNX file, and the results were correct. In this section, we start from the ONNX file, generate the engine file step by step, and use TensorRT for inference.


1. What is TensorRT?

  NVIDIA TensorRT™ is an SDK for high-performance deep learning inference. The SDK includes a deep learning inference optimizer and a runtime that delivers low latency and high throughput for deep learning inference applications. In layman's terms, TensorRT is an inference framework developed by NVIDIA for its own GPUs: it applies a series of algorithms and optimizations to the network to improve the inference speed of deep learning models on the GPU.
We use the TensorRT framework to speed up inference for our handwritten digit classification model.
I also wrote a blog post about how to install TensorRT: refer to it here.

Here we assume that TensorRT is already installed; the version I use is TensorRT-8.0.1.6. Before generating the engine file, let me first introduce a useful tool, trtexec. trtexec is a command-line tool that can generate engines for us without writing any code, and it offers many other useful functions as well. Interested readers can explore it on their own; here we only use a few common command-line parameters.
For the detailed parameters of trtexec, please refer to this blog post.

2. How to generate an engine from ONNX

  To recap, we now have the ONNX file and TensorRT installed, and our goal is to generate the engine file. We already introduced what an ONNX file is, so what is an engine file?

The engine file in TensorRT is a binary file that contains an optimized deep learning model. This file can be used for inference without reloading and re-optimizing the model. When using TensorRT for inference, you first need to convert the trained model into a TensorRT engine file, and then use this file for inference.

In other words, we only need to generate an engine once, and this engine file contains the optimized model (this optimization is done by TensorRT itself). For future inference, we only need to load this engine instead of starting from scratch.
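
For example, a minimal sketch of loading a previously built engine might look like this (assuming a model.engine file already exists; the full build and inference code follows below):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
# Read the serialized engine back and deserialize it. No ONNX parsing or
# optimization happens here, which is why loading an engine is fast.
with open('model.engine', 'rb') as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())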

Generate the engine using trtexec

TensorRT-8.0.1.6/bin/trtexec --onnx=model.onnx --saveEngine=model.engine --buildOnly

Entering this command on the command line generates model.engine for us. The trtexec command has many other parameters, which are worth exploring on your own. Here we only use --onnx, which specifies the input ONNX file, --saveEngine, which specifies where to save the engine file, and --buildOnly, which means only build the engine and skip inference.
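
For example, if you also want to try a half-precision build (this is optional and not used elsewhere in this post; it assumes your GPU supports FP16), the command could look like:

TensorRT-8.0.1.6/bin/trtexec --onnx=model.onnx --saveEngine=model_fp16.engine --buildOnly --fp16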

Use the Python interface

The code is as follows (example):

import os
import tensorrt as trt

onnx_file = '/home/wjq/wjqHD/pytorch_mnist/model.onnx'
nHeight, nWidth = 28, 28
trtFile = '/home/wjq/wjqHD/pytorch_mnist/model.engine'

# Parse network, rebuild network, and build engine, then save engine
logger = trt.Logger(trt.Logger.VERBOSE)

builder = trt.Builder(logger)

# Create an explicit-batch network definition, an optimization profile, and a builder config
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
profile = builder.create_optimization_profile()
config = builder.create_builder_config()

# ONNX parser that populates the network definition from the ONNX file
parser = trt.OnnxParser(network, logger)

if not os.path.exists(onnx_file):
    print('ONNX file {} not found.'.format(onnx_file))
    exit()
print("Loading ONNX file from path {}...".format(onnx_file))

with open(onnx_file, 'rb') as model:
    if not parser.parse(model.read()):
        print('ERROR: Failed to parse the ONNX file.')
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        exit()
    
    print("Succeed to parse the ONNX file.")

input_tensor = network.get_input(0)
# Set the input shape for the optimization profile (min, opt and max are all 1x1x28x28 here)
profile.set_shape(input_tensor.name, [1, 1, nHeight, nWidth], [1, 1, nHeight, nWidth], [1, 1, nHeight, nWidth])
config.add_optimization_profile(profile)

engineString = builder.build_serialized_network(network, config)  # build and serialize the engine
if engineString is None:
    print("Failed building engine!")
    exit()
print("Succeeded building engine!")
with open(trtFile, "wb") as f:
    f.write(engineString)

Using the Python code above, we can also generate an engine file. For the APIs used in this code, you can look up explanations online; I am just showing one possible approach here. If you have any questions, feel free to discuss them in the comment section.

We can also use the trtexec tool to verify that the engine we generated works correctly. The command is:

TensorRT-8.0.1.6/bin/trtexec --loadEngine=model.engine --exportProfile=layerProfile.json --batch=1 --warmUp=1000 --verbose

--loadEngine specifies the path of the engine file to load, --exportProfile outputs the average running time of each layer in the network and its percentage of the total time, --verbose prints verbose logs, and --warmUp warms up the GPU before timing.

3. Run inference

  We have obtained the model.engine file. In the last step, we need to use the TensorRT interface to read the engine file and an image file, run inference, and obtain the final classification result.
  Since the pycuda and cuda Python packages cannot be installed in my current environment, the final inference step will be added once the environment is ready.
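
Until then, here is a rough, untested sketch of what that step could look like, assuming pycuda can be installed and that the engine uses the fixed 1x1x28x28 input and 10 output classes from this series:

import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # creates a CUDA context

trtFile = '/home/wjq/wjqHD/pytorch_mnist/model.engine'

logger = trt.Logger(trt.Logger.WARNING)
with open(trtFile, 'rb') as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

# Dummy input; replace with a real, normalized 28x28 MNIST image.
input_data = np.random.rand(1, 1, 28, 28).astype(np.float32)
output_data = np.empty((1, 10), dtype=np.float32)  # 10 digit classes

# Allocate device memory and copy the input to the GPU.
d_input = cuda.mem_alloc(input_data.nbytes)
d_output = cuda.mem_alloc(output_data.nbytes)
cuda.memcpy_htod(d_input, input_data)

# Run inference and copy the result back to the host.
context.execute_v2([int(d_input), int(d_output)])
cuda.memcpy_dtoh(output_data, d_output)

print('Predicted digit:', int(np.argmax(output_data)))

Note that if the ONNX model was exported with a dynamic input shape, you may also need context.set_binding_shape(0, (1, 1, 28, 28)) before executing.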

Summary

  In this section, we introduced how to use the trtexec tool and Python code to generate an engine file from ONNX, and how to use TensorRT's API to load the engine file for inference. This TensorRT handwritten digit classification series has three sections in total, and it broadly covers the process of deploying a deep learning model. I hope everyone can gain something from it. Next, if I have time, I plan to write another post on what to do when PyTorch, ONNX, or TensorRT encounters an unsupported operator.
