tensorRT部署之 代码实现 onnx转engine/trt模型

tensorRT部署之 代码实现 onnx转engine/trt模型

  • 前提已经装好显卡驱动、cuda、cudnn、以及tensorRT
  • 下面将给出Python、C++两种转换方式

1. C++实现

  • 项目属性配置好CUDA、tensoeRT库
  • 通常在实际应用中会直接读取onnx模型进行判断,如果对应路径已经存在engine模型,将直接通过tensorrt读入engine,如果没有,则对onnx进行编译生成engine模型后在进行读入
  • TensorRT在线加载模型,并序列化保存支持动态batch的引擎,实现源码可参考 TextandCode
  • 一篇超级详细的onnx基础教程(非常好):TextandCode
  • 代码实现:
#include <iostream>
#include <fstream>
#include "NvInfer.h"
#include "NvOnnxParser.h"
 
// 实例化记录器界面。捕获所有警告消息,但忽略信息性消息
class Logger : public nvinfer1::ILogger
{
    
    
    void log(Severity severity, const char* msg) noexcept override
    {
    
    
        // suppress info-level messages
        if (severity <= Severity::kWARNING)
            std::cout << msg << std::endl;
    }
} logger;
 
 
void ONNX2TensorRT(const char* ONNX_file, std::string save_ngine)
{
    
    
    // 1.创建构建器的实例
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
 
    // 2.创建网络定义
    uint32_t flag = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(flag);
 
    // 3.创建一个 ONNX 解析器来填充网络
    nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger);
 
    // 4.读取模型文件并处理任何错误
    parser->parseFromFile(ONNX_file, static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));
    for (int32_t i = 0; i < parser->getNbErrors(); ++i)
    {
    
    
        std::cout << parser->getError(i)->desc() << std::endl;
    }
 
    // 5.创建一个构建配置,指定 TensorRT 应该如何优化模型
    nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
 
    // 6.设置属性来控制 TensorRT 如何优化网络
    // 设置内存池的空间
    config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 16 * (1 << 20));
    // 设置低精度   注释掉为FP32
    if (builder->platformHasFastFp16())
    {
    
    
        config->setFlag(nvinfer1::BuilderFlag::kFP16);
    }
 
    // 7.指定配置后,构建引擎
    nvinfer1::IHostMemory* serializedModel = builder->buildSerializedNetwork(*network, *config);
 
    // 8.保存TensorRT模型
    std::ofstream p(save_ngine, std::ios::binary);
    p.write(reinterpret_cast<const char*>(serializedModel->data()), serializedModel->size());
 
    // 9.序列化引擎包含权重的必要副本,因此不再需要解析器、网络定义、构建器配置和构建器,可以安全地删除
    delete parser;
    delete network;
    delete config;
    delete builder;
 
    // 10.将引擎保存到磁盘,并且可以删除它被序列化到的缓冲区
    delete serializedModel;
}
 
 
void exportONNX(const char* ONNX_file, std::string save_ngine)
{
    
    
    std::ifstream file(ONNX_file, std::ios::binary);
    if (!file.good())
    {
    
    
        std::cout << "Load ONNX file failed! No file found from:" << ONNX_file << std::endl;
        return ;
    }
 
    std::cout << "Load ONNX file from: " << ONNX_file << std::endl;
    std::cout << "Starting export ..." << std::endl;
 
    ONNX2TensorRT(ONNX_file, save_ngine);
 
    std::cout << "Export success, saved as: " << save_ngine << std::endl;
 
}
 
 
int main(int argc, char** argv)
{
    
    
    // 输入信息
    const char* ONNX_file  = "../weights/test.onnx";
    std::string save_ngine = "../weights/test.engine";
 
    exportONNX(ONNX_file, save_ngine);
 
    return 0;
}
 

猜你喜欢

转载自blog.csdn.net/yohnyang/article/details/131342398