Converting an ONNX Model and a Custom-Plugin Shared Library to a TensorRT Engine for Inference

ONNX is a format for representing machine-learning models, and TensorRT is NVIDIA's high-performance inference engine for running models on NVIDIA GPUs. A custom plugin is a user-implemented operator registered with TensorRT, typically to cover an op that TensorRT does not support natively or to supply a hand-optimized implementation, so that the whole model can run inside TensorRT. The detailed steps for converting an ONNX model to TensorRT and using a custom plugin are:

  1. Define the custom plugin
    First, define a plugin class that derives from TensorRT's plugin interface (nvinfer1::IPluginV2DynamicExt in current releases) and implement its virtual functions, such as enqueue() for the forward computation, getOutputDimensions(), and the serialization methods. The class is usually written in C++ and, as in this article, compiled into a shared library. Finally, register its creator with TensorRT's plugin registry, either with the REGISTER_TENSORRT_PLUGIN macro or via getPluginRegistry()->registerCreator(); a minimal sketch follows this list.

  2. Convert the ONNX model to a TensorRT engine
    Use TensorRT's Python API or C++ API to convert the ONNX model into a TensorRT engine. This is done in the following steps:

  • Create a Builder object (createInferBuilder); it is used to build the TensorRT engine.
  • Create a Network object from the Builder (createNetworkV2); it holds the network definition.
  • Create an ONNX parser (nvonnxparser::createParser) attached to the Network and parse the ONNX model into it.
  • Create an IBuilderConfig object; it carries the various build and inference configuration parameters (precision flags, workspace size, and so on).
  • Build an ICudaEngine from the Builder, the Network, and the config (buildEngineWithConfig).
  3. Apply the custom plugin
    Use TensorRT's C++ API or Python API to make the custom plugin available to the engine, both at build time and at inference time. The steps are:
  • Register the custom plugin with TensorRT's global plugin registry, either by loading the plugin shared library with dlopen() (its REGISTER_TENSORRT_PLUGIN static initializer runs on load) or by calling getPluginRegistry()->registerCreator() explicitly.
  • During parsing, the ONNX parser maps the custom op onto the registered plugin inside the INetworkDefinition, so the built engine contains the plugin layer.
  • Create an IExecutionContext from the ICudaEngine; it is used to execute inference.
  • Allocate input and output buffers on the device and bind them to the IExecutionContext through the bindings array passed at enqueue time.
  • Run inference.
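
The article does not show the source of libvit_plugin.so, so here is a minimal sketch of what such a plugin library could look like, assuming TensorRT 8.x and the IPluginV2DynamicExt interface. The plugin name "TransformerPlugin" and version "1" match what the loader code below looks up; the identity enqueue() body, the class names, and the build command in the trailing comment are illustrative assumptions rather than the article's actual implementation.

#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <string>

using namespace nvinfer1;

class IdentityPlugin : public IPluginV2DynamicExt
{
public:
    IPluginV2DynamicExt* clone() const noexcept override { return new IdentityPlugin(); }
    DimsExprs getOutputDimensions(int index, const DimsExprs* inputs, int nbInputs,
                                  IExprBuilder& exprBuilder) noexcept override
    {
        return inputs[0];                       // identity: output shape == input shape
    }
    bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut,
                                   int nbInputs, int nbOutputs) noexcept override
    {
        return inOut[pos].type == DataType::kFLOAT && inOut[pos].format == TensorFormat::kLINEAR;
    }
    void configurePlugin(const DynamicPluginTensorDesc*, int,
                         const DynamicPluginTensorDesc*, int) noexcept override {}
    size_t getWorkspaceSize(const PluginTensorDesc*, int,
                            const PluginTensorDesc*, int) const noexcept override { return 0; }
    int enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc*,
                const void* const* inputs, void* const* outputs,
                void*, cudaStream_t stream) noexcept override
    {
        // Illustrative only: copy the input tensor to the output (identity op)
        size_t count = 1;
        for (int i = 0; i < inputDesc[0].dims.nbDims; ++i) count *= inputDesc[0].dims.d[i];
        cudaMemcpyAsync(outputs[0], inputs[0], count * sizeof(float),
                        cudaMemcpyDeviceToDevice, stream);
        return 0;
    }
    DataType getOutputDataType(int, const DataType* inputTypes, int) const noexcept override
    {
        return inputTypes[0];
    }
    const char* getPluginType() const noexcept override { return "TransformerPlugin"; }
    const char* getPluginVersion() const noexcept override { return "1"; }
    int getNbOutputs() const noexcept override { return 1; }
    int initialize() noexcept override { return 0; }
    void terminate() noexcept override {}
    size_t getSerializationSize() const noexcept override { return 0; }
    void serialize(void*) const noexcept override {}
    void destroy() noexcept override { delete this; }
    void setPluginNamespace(const char* ns) noexcept override { mNamespace = ns; }
    const char* getPluginNamespace() const noexcept override { return mNamespace.c_str(); }

private:
    std::string mNamespace;
};

class IdentityPluginCreator : public IPluginCreator
{
public:
    const char* getPluginName() const noexcept override { return "TransformerPlugin"; }
    const char* getPluginVersion() const noexcept override { return "1"; }
    const PluginFieldCollection* getFieldNames() noexcept override { return &mFC; }
    IPluginV2* createPlugin(const char*, const PluginFieldCollection*) noexcept override
    {
        return new IdentityPlugin();
    }
    IPluginV2* deserializePlugin(const char*, const void*, size_t) noexcept override
    {
        return new IdentityPlugin();
    }
    void setPluginNamespace(const char* ns) noexcept override { mNamespace = ns; }
    const char* getPluginNamespace() const noexcept override { return mNamespace.c_str(); }

private:
    PluginFieldCollection mFC{};
    std::string mNamespace;
};

// Static registration: when libvit_plugin.so is dlopen()ed, this makes the creator
// discoverable via getPluginRegistry()->getPluginCreator("TransformerPlugin", "1").
REGISTER_TENSORRT_PLUGIN(IdentityPluginCreator);

// Example build command (an assumption, adjust paths to your setup):
//   g++ -shared -fPIC vit_plugin.cpp -o libvit_plugin.so -lnvinfer -lcudart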

When implementing the steps above, pay attention to the TensorRT version and the system configuration (the GPU driver and CUDA toolkit must match the TensorRT build). Using a recent TensorRT release on an NVIDIA GPU is recommended for the best performance and feature coverage.
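
As a quick sanity check of the environment, the TensorRT version the program is actually linked against can be printed at runtime; below is a minimal sketch using the getInferLibVersion() helper from the runtime API.

#include <NvInferRuntime.h>
#include <iostream>

int main()
{
    // getInferLibVersion() encodes the linked TensorRT version as
    // major * 1000 + minor * 100 + patch in the 8.x scheme (e.g. 8601 for 8.6.1)
    std::cout << "Linked TensorRT version: " << getInferLibVersion() << std::endl;
    return 0;
}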

#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <cuda_runtime_api.h>
#include <dlfcn.h>

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Minimal logger; the original article presumably defines its own gLogger elsewhere.
static class Logger : public nvinfer1::ILogger
{
    void log(Severity severity, const char* msg) noexcept override
    {
        if (severity <= Severity::kWARNING)
            std::cout << msg << std::endl;
    }
} gLogger;

int onnx_with_plugin_create_engine(std::string root_dir)
{
    // Paths to the ONNX model, the serialized engine cache, and the custom plugin library
    std::string onnx_file = root_dir + "model.onnx";
    std::string modeltrt = root_dir + "model.trt";
    std::string plugin_file = root_dir + "libvit_plugin.so";
    // Opening the cache file for reading only tells us whether a serialized engine already exists
    std::fstream trtCache(modeltrt, std::ifstream::in);
    nvinfer1::ICudaEngine* engine_ = nullptr;
    // Load the plugin shared library; its static initializers register the plugin creator
    void* pluginLibrary = dlopen(plugin_file.c_str(), RTLD_LAZY);
    if (!pluginLibrary) {
        std::cerr << "ERROR: Could not load plugin dynamic library" << std::endl;
        return EXIT_FAILURE;
    }

    // Look up the plugin creator that the shared library registered with TensorRT.
    // The creator is not used directly below; the lookup just confirms that
    // "TransformerPlugin" (version "1") is visible to the parser and the deserializer.
    auto creator = getPluginRegistry()->getPluginCreator("TransformerPlugin", "1");
    if (!creator) {
        std::cerr << "Failed to find plugin creator." << std::endl;
        return EXIT_FAILURE;
    }
    
    if (!trtCache.is_open())
    {
        std::cout << "Building TRT engine." << std::endl;

        // define builder
        auto builder = (nvinfer1::createInferBuilder(gLogger));
        // define network 
        const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
        auto network = (builder->createNetworkV2(explicitBatch));
        // define onnxparser
        auto parser = (nvonnxparser::createParser(*network, gLogger));
        if (!parser->parseFromFile(onnx_file.data(), static_cast<int>(nvinfer1::ILogger::Severity::kWARNING)))
        {
            std::cerr << "Failed to parse the ONNX model file; check the ONNX opset version and TensorRT op support!"
                      << std::endl;
            exit(-1);
        }
        // define config
        auto networkConfig = builder->createBuilderConfig();
        // setFlag FP16
        // networkConfig->setFlag(nvinfer1::BuilderFlag::kFP16);
        // std::cout << "Enable fp16!" << std::endl;

        // set max batch size (deprecated and ignored for explicit-batch networks; kept from the original code)
        builder->setMaxBatchSize(1);
        // set max workspace size to 1 GiB
        networkConfig->setMaxWorkspaceSize(size_t(1) << 30);

        engine_ = builder->buildEngineWithConfig(*network, *networkConfig);
        if (engine_ == nullptr)
        {
            std::cerr << "Engine initialization returned null." << std::endl;
            exit(-1);
        }
        // serialize the engine and cache it to disk, then release the builder objects
        auto trtModelStream = engine_->serialize();
        std::ofstream trtOut(modeltrt, std::ios::binary);
        if (!trtOut.is_open())
        {
            std::cerr << "Can't store the TRT engine." << std::endl;
            exit(-1);
        }
        trtOut.write((char*)trtModelStream->data(), trtModelStream->size());
        trtOut.close();

        
        trtModelStream->destroy();
        networkConfig->destroy();
        parser->destroy();
        network->destroy();
        builder->destroy();
        std::cerr << "build engine done." << std::endl;
    }
    else
    {
        std::cout << "Load engine: " << modeltrt << std::endl;
        std::ifstream engineFile(modeltrt, std::ios::binary);
        long int fsize = 0;
        engineFile.seekg(0, engineFile.end);
        fsize = engineFile.tellg();
        engineFile.seekg(0, engineFile.beg);
        std::vector<char> engineString(fsize);
        engineFile.read(engineString.data(), fsize);
        if (engineString.size() == 0)
        {
            std::cout << "Failed getting serialized engine!" << std::endl;
            exit(-1);
        }
        std::cout << "Succeeded getting serialized engine." << std::endl;
        
        nvinfer1::IRuntime* runtime{nvinfer1::createInferRuntime(gLogger)};
        // safe::IRuntime* runtime{safe::createInferRuntime(gLogger)}; // use the safe runtime instead
        engine_ = runtime->deserializeCudaEngine(engineString.data(), fsize);
        if (engine_ == nullptr)
        {
            std::cerr << "Failed loading engine." << std::endl;
            exit(-1);
        }
        std::cerr << "Succeeded loading engine." << std::endl;
        engineFile.close();
       
    }
    // inference
    // Host-side buffers sized for the model's input (1 x 3 x 1152 x 1152 float) and output (1 x 144 x 144 int32)
    int inputSize = 1 * 3 * 1152 * 1152;
    int outputSize = 1 * 144 * 144;
    std::vector<float> inputBuffer(inputSize);
    std::vector<int32_t> outputBuffer(outputSize);
    // Select the GPU, force CUDA context creation, and create a stream for asynchronous execution
    cudaSetDevice(0);
    cudaFree(0);
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Device-side buffers for the input and output bindings
    void *d_inputBuffer = nullptr;
    cudaMalloc(&d_inputBuffer, inputSize*sizeof(float));
    void *d_outputBuffer = nullptr;
    cudaMalloc(&d_outputBuffer, outputSize*sizeof(int32_t));
    // Create an execution context from the engine
    nvinfer1::IExecutionContext* context = engine_->createExecutionContext();
    if (!context)
    {
        std::cerr << "Failed to create execution context" << std::endl;
        return 1;
    }
    // Fill the input with dummy data and copy it to the device
    for (int i = 0; i < inputSize; ++i) {
        inputBuffer[i] = i % 255;
    }
    cudaMemcpyAsync(d_inputBuffer, inputBuffer.data(), inputSize * sizeof(float), cudaMemcpyHostToDevice, stream);

    // The bindings array is assumed to be ordered {input, output};
    // use engine_->getBindingIndex() if the binding order is not known in advance
    void *buffers[] = {d_inputBuffer, d_outputBuffer};
    context->enqueueV2(buffers, stream, nullptr);
    // Copy the result back and wait for the stream to finish before reading it on the host
    cudaMemcpyAsync(outputBuffer.data(), d_outputBuffer, outputSize * sizeof(int32_t), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    for (size_t i = 0; i < 30; i++)
    {
        if (outputBuffer[i] > 0)
        {
            std::cerr << outputBuffer[i] << std::endl;
        }
    }
    
    // Release TensorRT and CUDA resources
    context->destroy();
    engine_->destroy();
    cudaFree(d_inputBuffer);
    cudaFree(d_outputBuffer);
    cudaStreamDestroy(stream);
    dlclose(pluginLibrary);
    return EXIT_SUCCESS;
}
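
A hypothetical call site for the function above; the ./models/ directory is an assumption, and root_dir is expected to end with a path separator because the file names are appended to it directly.

int main()
{
    // root_dir must contain model.onnx and libvit_plugin.so; model.trt is created or reused there
    return onnx_with_plugin_create_engine("./models/");
}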

Reprinted from blog.csdn.net/qq_39506862/article/details/131023189