Analysis of SampleOnnxMNIST.cpp in TensorRT

# The workflow from an ONNX model to TensorRT

  • Convert the ONNX model into a TensorRT network

  • Build the engine

  • Run inference with the generated TensorRT engine

## Structure of the SampleOnnxMNIST class
class SampleOnnxMNIST
{
template <typename T>
using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;
// Defines SampleUniquePtr as an alias for unique_ptr (a smart pointer with exclusive
// ownership of its resource: only one unique_ptr may own a given object at a time, and
// its deleter runs automatically when the pointer goes out of scope). Here the deleter
// is samplesCommon::InferDeleter.

public:
    SampleOnnxMNIST(const samplesCommon::OnnxSampleParams& params)
        : mParams(params)
        , mEngine(nullptr)
    {
    }
    // The ": mParams(params), mEngine(nullptr)" above is a member initializer list with
    // two entries: mParams is initialized from the constructor argument params, and
    // mEngine is initialized to a null pointer.
    //!
    //! \brief Builds the engine
    //!
    bool build();

    //!
    //! \brief Runs inference with the generated TensorRT engine
    //!
    bool infer();

private:
    samplesCommon::OnnxSampleParams mParams; //!< The parameters for the sample.

    nvinfer1::Dims mInputDims;  //!< The dimensions of the network input
    nvinfer1::Dims mOutputDims; //!< The dimensions of the network output
    int mNumber{0};             //!< The number to classify

    std::shared_ptr<nvinfer1::ICudaEngine> mEngine; //!< The TensorRT engine built from the ONNX model

    //!
    //! \brief Converts the ONNX model into a TensorRT network
    //!
    bool constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
        SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
        SampleUniquePtr<nvonnxparser::IParser>& parser);

    //!
    //! \brief Reads the input and stores it in a managed buffer
    //!
    bool processInput(const samplesCommon::BufferManager& buffers);

    //!
    //! \brief Classifies digits and verifies the result
    //!
    bool verifyOutput(const samplesCommon::BufferManager& buffers);
};
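
The custom deleter matters because TensorRT objects of this generation are not released with `delete` but through their `destroy()` method. A minimal sketch of what samplesCommon::InferDeleter does (the real definition lives in the samples' common code):

    // Sketch of samplesCommon::InferDeleter (assumption: matches the
    // TensorRT 7-era samples): release TensorRT objects via destroy().
    struct InferDeleter
    {
        template <typename T>
        void operator()(T* obj) const
        {
            if (obj)
            {
                obj->destroy(); // TensorRT's own teardown, not operator delete
            }
        }
    };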

## The main function

int main(int argc, char** argv)
{   // Parse the command-line arguments
    samplesCommon::Args args;
    bool argsOK = samplesCommon::parseArgs(args, argc, argv);
    if (!argsOK)
    {
        sample::gLogError << "Invalid arguments" << std::endl;
        printHelpInfo();
        return EXIT_FAILURE;
    }
    if (args.help)
    {
        printHelpInfo();
        return EXIT_SUCCESS;
    }
    
    // Define a logger test object for recording and printing output
    auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
    // Mark the start of the test
    sample::gLogger.reportTestStart(sampleTest);

    // Parse the arguments with initializeSampleParams and use them to initialize SampleOnnxMNIST sample
    SampleOnnxMNIST sample(initializeSampleParams(args));

    sample::gLogInfo << "Building and running a GPU inference engine for Onnx MNIST" << std::endl;

    // Build the TensorRT engine
    if (!sample.build())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    // Run inference
    if (!sample.infer())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    // Report the final result
    return sample::gLogger.reportPass(sampleTest);
}

The main work happens in three stages:

  • SampleOnnxMNIST sample(initializeSampleParams(args));
  • sample.build();
  • sample.infer();
  1. The first to run is SampleOnnxMNIST sample(initializeSampleParams(args)):

     samplesCommon::OnnxSampleParams initializeSampleParams(const samplesCommon::Args& args)
     {
         samplesCommon::OnnxSampleParams params;
         if (args.dataDirs.empty()) //!< Use default directories if user hasn't provided directory paths
         {
             params.dataDirs.push_back("data/mnist/");
             params.dataDirs.push_back("data/samples/mnist/");
         }
         else //!< Use the data directory provided by the user
         {
             params.dataDirs = args.dataDirs;
         }
         params.onnxFileName = "mnist.onnx";
         params.inputTensorNames.push_back("Input3");
         params.outputTensorNames.push_back("Plus214_Output_0");
         params.dlaCore = args.useDLACore;
         params.int8 = args.runInInt8;
         params.fp16 = args.runInFp16;
    
         return params;
     }
    

    This sets the data directories, the ONNX file name, the names of the input and output tensors, the DLA core to use, and the execution precision, then returns params, a struct of type samplesCommon::OnnxSampleParams.

    The SampleOnnxMNIST sample object is then instantiated, initializing the member mParams to params and mEngine to a null pointer.
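
    To point the sample at your own model, these are the only fields that need to change. A hypothetical adaptation ("my_model.onnx", "input", and "output" are placeholders; read the real tensor names from the ONNX file, e.g. with Netron):

     // Hypothetical adaptation; the file and tensor names below are
     // placeholders, not values from the sample.
     params.onnxFileName = "my_model.onnx";
     params.inputTensorNames.push_back("input");
     params.outputTensorNames.push_back("output");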


  2. sample.build();

     bool SampleOnnxMNIST::build()
     {
         auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
         if (!builder)
         {
             return false;
         }
    
         const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
         auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
         if (!network)
         {
             return false;
         }
    
         auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
         if (!config)
         {
             return false;
         }
    
         auto parser
             = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
         if (!parser)
         {
             return false;
         }
    
         auto constructed = constructNetwork(builder, network, config, parser);
         if (!constructed)
         {
             return false;
         }
    
         mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
             builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
         if (!mEngine)
         {
             return false;
         }
    
         assert(network->getNbInputs() == 1);
         mInputDims = network->getInput(0)->getDimensions();
         assert(mInputDims.nbDims == 4);
    
         assert(network->getNbOutputs() == 1);
         mOutputDims = network->getOutput(0)->getDimensions();
         assert(mOutputDims.nbDims == 2);
    
         return true;
     }
    

    The core call here is auto constructed = constructNetwork(builder, network, config, parser);

     //!
     //! \brief Uses the ONNX parser to create the network and marks the output layers
     //! The network argument receives the parsed ONNX network
     //! \param network Pointer to the network that will be populated with the Onnx MNIST network
     //!
     //! \param builder Pointer to the engine builder
     //!
     bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
         SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
         SampleUniquePtr<nvonnxparser::IParser>& parser)
     {
         auto parsed = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),
             static_cast<int>(sample::gLogger.getReportableSeverity()));
         if (!parsed)
         {
             return false;
         }
    
         config->setMaxWorkspaceSize(16_MiB); // _MiB is a user-defined literal from the samples' common code
         if (mParams.fp16)
         {
             config->setFlag(BuilderFlag::kFP16);
         }
         if (mParams.int8)
         {
             config->setFlag(BuilderFlag::kINT8);
             // The sample ships no INT8 calibrator, so dummy per-tensor scales are set instead
             samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
         }
    
         samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
    
         return true;
     }
    

    builder->buildEngineWithConfig is then called to build the network into a TensorRT engine according to the config set up in constructNetwork.
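
    Because building an engine is the slow step, a common extension (not part of this sample) is to cache the built engine on disk and skip the build on later runs. A sketch using the TensorRT 7-era API:

     // Sketch (not in SampleOnnxMNIST): serialize the built engine to disk.
     // Uses TensorRT 7-era destroy() semantics.
     #include <fstream>

     nvinfer1::IHostMemory* serialized = mEngine->serialize();
     std::ofstream engineFile("mnist.engine", std::ios::binary);
     engineFile.write(static_cast<const char*>(serialized->data()), serialized->size());
     serialized->destroy();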
    Finally, this block:

     assert(network->getNbInputs() == 1);
     mInputDims = network->getInput(0)->getDimensions();
     assert(mInputDims.nbDims == 4);
    
     assert(network->getNbOutputs() == 1);
     mOutputDims = network->getOutput(0)->getDimensions();
     assert(mOutputDims.nbDims == 2);
    

    must be adjusted to match the input and output tensor shapes of the specific network being converted.
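
    When adapting those asserts, it helps to print the parsed dimensions first. A small hypothetical helper (printDims is not part of the sample):

     // Hypothetical helper for inspecting tensor shapes before hard-coding asserts.
     #include <iostream>
     #include <NvInfer.h>

     void printDims(const char* name, const nvinfer1::Dims& dims)
     {
         std::cout << name << ": [";
         for (int i = 0; i < dims.nbDims; ++i)
         {
             std::cout << dims.d[i] << (i + 1 < dims.nbDims ? ", " : "");
         }
         std::cout << "]" << std::endl;
     }

     // e.g. printDims("input", network->getInput(0)->getDimensions());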


  3. sample.infer();
    Runs TensorRT inference: first allocate the buffers, then set up the input, and finally execute the engine.

     bool SampleOnnxMNIST::infer()
     {
         // Create RAII buffer manager object
         samplesCommon::BufferManager buffers(mEngine);
    
         auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
         if (!context)
         {
             return false;
         }
    
         // Read the input data into the managed buffers
         assert(mParams.inputTensorNames.size() == 1);
         if (!processInput(buffers))
         {
             return false;
         }
    
         // Memcpy from host input buffers to device input buffers
         buffers.copyInputToDevice();
    
         bool status = context->executeV2(buffers.getDeviceBindings().data());
         if (!status)
         {
             return false;
         }
    
         // Memcpy from device output buffers to host output buffers
         buffers.copyOutputToHost();
    
         // Verify results
         if (!verifyOutput(buffers))
         {
             return false;
         }
    
         return true;
     }
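
    Note that executeV2 blocks the calling thread. A minimal sketch of the asynchronous alternative, assuming the async copy helpers (copyInputToDeviceAsync / copyOutputToHostAsync) that ship in the samples' buffers.h:

     // Async variant (sketch): enqueueV2 queues the work on a CUDA stream
     // instead of blocking like executeV2. Reuses buffers and context from
     // the listing above; needs <cuda_runtime_api.h>.
     cudaStream_t stream;
     cudaStreamCreate(&stream);

     buffers.copyInputToDeviceAsync(stream);
     bool status = context->enqueueV2(buffers.getDeviceBindings().data(), stream, nullptr);
     buffers.copyOutputToHostAsync(stream);

     cudaStreamSynchronize(stream); // wait for the copies and inference to finish
     cudaStreamDestroy(stream);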
    

    The core calls are:

    • processInput(buffers): preprocessing

    • context->executeV2(buffers.getDeviceBindings().data()): runs the prediction

    • verifyOutput(buffers): postprocessing

        bool SampleOnnxMNIST::processInput(const samplesCommon::BufferManager& buffers)
        {
            const int inputH = mInputDims.d[2];
            const int inputW = mInputDims.d[3];
      
            // Read a random digit file
            srand(unsigned(time(nullptr)));
            std::vector<uint8_t> fileData(inputH * inputW);
            mNumber = rand() % 10;
            readPGMFile(locateFile(std::to_string(mNumber) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);
      
            // Render the digit in the terminal as ASCII art (not needed in real applications)
            sample::gLogInfo << "Input:" << std::endl;
            for (int i = 0; i < inputH * inputW; i++)
            {
                sample::gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
            }
            sample::gLogInfo << std::endl;
      
            // Normalize to [0,1] and invert: the PGM stores a dark digit on a light
            // background, while the model expects a bright digit on a dark background
            float* hostDataBuffer = static_cast<float*>(buffers.getHostBuffer(mParams.inputTensorNames[0]));
            for (int i = 0; i < inputH * inputW; i++)
            {
                hostDataBuffer[i] = 1.0 - float(fileData[i] / 255.0);
            }
      
            return true;
        }
      
      
        bool SampleOnnxMNIST::verifyOutput(const samplesCommon::BufferManager& buffers)
        {
            const int outputSize = mOutputDims.d[1];
            float* output = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[0]));
            float val{0.0f};
            int idx{0};
      
            // Calculate Softmax
            float sum{0.0f};
            for (int i = 0; i < outputSize; i++)
            {
                output[i] = exp(output[i]);
                sum += output[i];
            }
      
            sample::gLogInfo << "Output:" << std::endl;
            for (int i = 0; i < outputSize; i++)
            {
                output[i] /= sum;
                val = std::max(val, output[i]);
                if (val == output[i])
                {
                    idx = i;
                }
      
                sample::gLogInfo << " Prob " << i << "  " << std::fixed << std::setw(5) << std::setprecision(4) << output[i]
                                << " "
                                << "Class " << i << ": " << std::string(int(std::floor(output[i] * 10 + 0.5f)), '*')
                                << std::endl;
            }
            sample::gLogInfo << std::endl;
      
            return idx == mNumber && val > 0.9f;
        }
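
    One caveat in verifyOutput: the softmax applies exp() to the raw logits, which can overflow for networks with large outputs (harmless for MNIST's small logits). A numerically stable sketch:

        // Numerically stable softmax (sketch): subtract the max logit before
        // exp() so the exponentials cannot overflow.
        #include <algorithm>
        #include <cmath>
        #include <vector>

        std::vector<float> stableSoftmax(const float* logits, int n)
        {
            const float maxVal = *std::max_element(logits, logits + n);
            std::vector<float> probs(n);
            float sum = 0.0f;
            for (int i = 0; i < n; ++i)
            {
                probs[i] = std::exp(logits[i] - maxVal);
                sum += probs[i];
            }
            for (int i = 0; i < n; ++i)
            {
                probs[i] /= sum;
            }
            return probs;
        }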
      


Reposted from blog.csdn.net/Johnson_star/article/details/107692357