# Workflow from an ONNX model to TensorRT
- Parse the ONNX model into a TensorRT network
- Build the engine
- Run inference with the generated TensorRT engine
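In terms of the TensorRT C++ API, the three stages reduce to roughly the calls below. This is a condensed sketch, not the full sample: logger, verbosity, and deviceBindings stand in for objects set up elsewhere, and error handling and cleanup are omitted.

// Condensed sketch of the three stages; logger, verbosity, deviceBindings are placeholders.
auto builder = nvinfer1::createInferBuilder(logger);
auto network = builder->createNetworkV2(
    1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
auto parser = nvonnxparser::createParser(*network, logger);
parser->parseFromFile("model.onnx", verbosity);                  // stage 1: parse ONNX into the network
auto config = builder->createBuilderConfig();
auto engine = builder->buildEngineWithConfig(*network, *config); // stage 2: build the engine
auto context = engine->createExecutionContext();
context->executeV2(deviceBindings);                              // stage 3: run inference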
## Structure of the SampleOnnxMNIST class
class SampleOnnxMNIST
{
    template <typename T>
    using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;
    // Alias a unique_ptr (a smart pointer with exclusive ownership of its resource:
    // only one unique_ptr may point to a given object at a time, and the Deleter is
    // invoked automatically when it goes out of scope) as SampleUniquePtr, used here
    // to manage the sample's objects; the Deleter is samplesCommon::InferDeleter.

public:
    SampleOnnxMNIST(const samplesCommon::OnnxSampleParams& params)
        : mParams(params)
        , mEngine(nullptr)
    {
    }
    // The ": mParams(params), mEngine(nullptr)" above is the member initializer list:
    // mParams is initialized from the constructor argument params, and mEngine starts
    // out as a null pointer.

    //!
    //! \brief Builds the engine
    //!
    bool build();

    //!
    //! \brief Runs inference with the generated TensorRT engine
    //!
    bool infer();

private:
    samplesCommon::OnnxSampleParams mParams; //!< The parameters for the sample.
    nvinfer1::Dims mInputDims;  //!< The dimensions of the network input.
    nvinfer1::Dims mOutputDims; //!< The dimensions of the network output.
    int mNumber{0};             //!< The number to classify.
    std::shared_ptr<nvinfer1::ICudaEngine> mEngine; //!< The TensorRT engine built from the ONNX model.

    //!
    //! \brief Parses the ONNX model into a TensorRT network
    //!
    bool constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
        SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
        SampleUniquePtr<nvonnxparser::IParser>& parser);

    //!
    //! \brief Reads the input and stores it in a managed buffer
    //!
    bool processInput(const samplesCommon::BufferManager& buffers);

    //!
    //! \brief Classifies digits and verifies the result
    //!
    bool verifyOutput(const samplesCommon::BufferManager& buffers);
};
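For reference, samplesCommon::InferDeleter in the samples' common.h is essentially a functor that calls destroy() on the managed object, which is how TensorRT objects of this era release their resources:

// Roughly how samplesCommon::InferDeleter is implemented in the samples' common.h.
struct InferDeleter
{
    template <typename T>
    void operator()(T* obj) const
    {
        if (obj)
        {
            obj->destroy(); // TensorRT objects are released via destroy(), not delete
        }
    }
};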
## The main function
int main(int argc, char** argv)
{
    // Parse the command-line arguments
    samplesCommon::Args args;
    bool argsOK = samplesCommon::parseArgs(args, argc, argv);
    if (!argsOK)
    {
        sample::gLogError << "Invalid arguments" << std::endl;
        printHelpInfo();
        return EXIT_FAILURE;
    }
    if (args.help)
    {
        printHelpInfo();
        return EXIT_SUCCESS;
    }

    // Define a test in the logger for recording and reporting output
    auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
    // Report the start of the test
    sample::gLogger.reportTestStart(sampleTest);

    // Initialize the SampleOnnxMNIST sample with the parameters produced by initializeSampleParams
    SampleOnnxMNIST sample(initializeSampleParams(args));

    sample::gLogInfo << "Building and running a GPU inference engine for Onnx MNIST" << std::endl;

    // Build the TensorRT engine
    if (!sample.build())
    {
        return sample::gLogger.reportFail(sampleTest);
    }
    // Run inference
    if (!sample.infer())
    {
        return sample::gLogger.reportFail(sampleTest);
    }
    // Report success
    return sample::gLogger.reportPass(sampleTest);
}
The main work falls into three stages:
- SampleOnnxMNIST sample(initializeSampleParams(args));
- sample.build();
- sample.infer();
The first stage, SampleOnnxMNIST sample(initializeSampleParams(args)):
samplesCommon::OnnxSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    samplesCommon::OnnxSampleParams params;
    if (args.dataDirs.empty()) //!< Use default directories if user hasn't provided directory paths
    {
        params.dataDirs.push_back("data/mnist/");
        params.dataDirs.push_back("data/samples/mnist/");
    }
    else //!< Use the data directory provided by the user
    {
        params.dataDirs = args.dataDirs;
    }
    params.onnxFileName = "mnist.onnx";
    params.inputTensorNames.push_back("Input3");
    params.outputTensorNames.push_back("Plus214_Output_0");
    params.dlaCore = args.useDLACore;
    params.int8 = args.runInInt8;
    params.fp16 = args.runInFp16;

    return params;
}
It fills in the data directories, the ONNX file name, the names of the input and output tensors, the DLA core to use, and the execution precision, and finally returns the populated samplesCommon::OnnxSampleParams struct params.
The SampleOnnxMNIST object sample is then constructed from it, initializing the member mParams to params and mEngine to a null pointer.
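These fields are also what you would change to adapt the sample to your own model. A hypothetical example (the directory, file, and tensor names below are placeholders; the tensor names must match what your ONNX export actually contains, e.g. as shown by Netron or by the parser's log output):

// Hypothetical parameters for a custom model; all names are placeholders.
samplesCommon::OnnxSampleParams params;
params.dataDirs.push_back("data/my_model/");    // placeholder data directory
params.onnxFileName = "my_model.onnx";          // placeholder ONNX file name
params.inputTensorNames.push_back("input_0");   // must match the ONNX graph's input name
params.outputTensorNames.push_back("output_0"); // must match the ONNX graph's output name
params.fp16 = true;                             // optional: request FP16 kernels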
The second stage, sample.build():
bool SampleOnnxMNIST::build()
{
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }

    const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
    if (!network)
    {
        return false;
    }

    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }

    auto parser = SampleUniquePtr<nvonnxparser::IParser>(
        nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
    if (!parser)
    {
        return false;
    }

    auto constructed = constructNetwork(builder, network, config, parser);
    if (!constructed)
    {
        return false;
    }

    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
        builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
    if (!mEngine)
    {
        return false;
    }

    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 4);

    assert(network->getNbOutputs() == 1);
    mOutputDims = network->getOutput(0)->getDimensions();
    assert(mOutputDims.nbDims == 2);

    return true;
}
Its core is the call auto constructed = constructNetwork(builder, network, config, parser):
//!
//! \brief Uses the ONNX parser to create the network and marks the output layers
//!
//! \param network Pointer to the network that will be populated with the Onnx MNIST network
//!
//! \param builder Pointer to the engine builder
//!
bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
    SampleUniquePtr<nvonnxparser::IParser>& parser)
{
    auto parsed = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),
        static_cast<int>(sample::gLogger.getReportableSeverity()));
    if (!parsed)
    {
        return false;
    }

    config->setMaxWorkspaceSize(16_MiB);
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
        samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
    }

    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);

    return true;
}
After constructNetwork returns, build() calls builder->buildEngineWithConfig, which builds the network into a TensorRT engine according to the config that constructNetwork set up.
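Since buildEngineWithConfig is the expensive step (it runs TensorRT's kernel auto-tuning), a common follow-up is to serialize the built engine to disk and deserialize it on later runs instead of rebuilding. A minimal sketch using the TensorRT API of this era (the file handling and the mnist.engine name are our additions, not part of the sample; needs <fstream>):

// Sketch: serialize the engine once, then reload it on later runs.
nvinfer1::IHostMemory* serialized = mEngine->serialize();
std::ofstream out("mnist.engine", std::ios::binary);
out.write(static_cast<const char*>(serialized->data()), serialized->size());
serialized->destroy();

// Later: deserialize instead of rebuilding (blobData/blobSize hold the file contents).
// auto runtime = nvinfer1::createInferRuntime(sample::gLogger.getTRTLogger());
// auto engine = runtime->deserializeCudaEngine(blobData, blobSize, nullptr);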
The final block

assert(network->getNbInputs() == 1);
mInputDims = network->getInput(0)->getDimensions();
assert(mInputDims.nbDims == 4);
assert(network->getNbOutputs() == 1);
mOutputDims = network->getOutput(0)->getDimensions();
assert(mOutputDims.nbDims == 2);

records the input and output dimensions and must be adjusted to match the shapes of your specific network's input and output tensors.
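If you are unsure what shapes your network exposes, one alternative to hard-coded asserts is to log every input and output tensor's dimensions first and then write asserts to match. A small sketch using only API calls already seen above (outputs are handled analogously via getNbOutputs/getOutput):

// Sketch: print the dimensions of all network inputs before asserting on them.
for (int i = 0; i < network->getNbInputs(); i++)
{
    nvinfer1::Dims dims = network->getInput(i)->getDimensions();
    sample::gLogInfo << "Input " << i << " (" << network->getInput(i)->getName() << "):";
    for (int d = 0; d < dims.nbDims; d++)
    {
        sample::gLogInfo << " " << dims.d[d];
    }
    sample::gLogInfo << std::endl;
}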
The third stage, sample.infer():
This runs TensorRT inference: first allocate the buffers, then set up the input, and finally execute the engine.

bool SampleOnnxMNIST::infer()
{
    // Create RAII buffer manager object
    samplesCommon::BufferManager buffers(mEngine);

    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    if (!context)
    {
        return false;
    }

    // Read the input data into the managed buffers
    assert(mParams.inputTensorNames.size() == 1);
    if (!processInput(buffers))
    {
        return false;
    }

    // Memcpy from host input buffers to device input buffers
    buffers.copyInputToDevice();

    bool status = context->executeV2(buffers.getDeviceBindings().data());
    if (!status)
    {
        return false;
    }

    // Memcpy from device output buffers to host output buffers
    buffers.copyOutputToHost();

    // Verify results
    if (!verifyOutput(buffers))
    {
        return false;
    }

    return true;
}
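Note that executeV2 is synchronous. To overlap transfers and compute, TensorRT also offers enqueueV2 on a CUDA stream, and the sample's BufferManager has matching async copy helpers. A hedged sketch of how the middle of infer() could look (error checking omitted; needs <cuda_runtime_api.h>):

// Sketch: asynchronous inference on a CUDA stream instead of the synchronous executeV2.
cudaStream_t stream;
cudaStreamCreate(&stream);
buffers.copyInputToDeviceAsync(stream);  // H2D copy queued on the stream
bool status = context->enqueueV2(buffers.getDeviceBindings().data(), stream, nullptr);
buffers.copyOutputToHostAsync(stream);   // D2H copy queued on the stream
cudaStreamSynchronize(stream);           // wait for all of the above to finish
cudaStreamDestroy(stream);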
Its core steps:
- processInput(buffers): preprocessing
- context->executeV2(buffers.getDeviceBindings().data()): runs the prediction
- verifyOutput(buffers): postprocessing
bool SampleOnnxMNIST::processInput(const samplesCommon::BufferManager& buffers)
{
    const int inputH = mInputDims.d[2];
    const int inputW = mInputDims.d[3];

    // Read a random digit file
    srand(unsigned(time(nullptr)));
    std::vector<uint8_t> fileData(inputH * inputW);
    mNumber = rand() % 10;
    readPGMFile(locateFile(std::to_string(mNumber) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);

    // Render the digit as ASCII art in the terminal (not needed in real applications)
    sample::gLogInfo << "Input:" << std::endl;
    for (int i = 0; i < inputH * inputW; i++)
    {
        sample::gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
    }
    sample::gLogInfo << std::endl;

    float* hostDataBuffer = static_cast<float*>(buffers.getHostBuffer(mParams.inputTensorNames[0]));
    for (int i = 0; i < inputH * inputW; i++)
    {
        hostDataBuffer[i] = 1.0 - float(fileData[i] / 255.0);
    }

    return true;
}

bool SampleOnnxMNIST::verifyOutput(const samplesCommon::BufferManager& buffers)
{
    const int outputSize = mOutputDims.d[1];
    float* output = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[0]));
    float val{0.0f};
    int idx{0};

    // Calculate Softmax
    float sum{0.0f};
    for (int i = 0; i < outputSize; i++)
    {
        output[i] = exp(output[i]);
        sum += output[i];
    }

    sample::gLogInfo << "Output:" << std::endl;
    for (int i = 0; i < outputSize; i++)
    {
        output[i] /= sum;
        val = std::max(val, output[i]);
        if (val == output[i])
        {
            idx = i;
        }

        sample::gLogInfo << " Prob " << i << " " << std::fixed << std::setw(5) << std::setprecision(4) << output[i]
                         << " " << "Class " << i << ": " << std::string(int(std::floor(output[i] * 10 + 0.5f)), '*')
                         << std::endl;
    }
    sample::gLogInfo << std::endl;

    return idx == mNumber && val > 0.9f;
}
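One caveat about verifyOutput: it exponentiates the raw logits directly, which can overflow for large logit values. Softmax is invariant to subtracting a constant from every logit, so a numerically stable variant subtracts the maximum first. A minimal sketch of how the softmax loop above could be hardened:

// Sketch: numerically stable softmax (subtract the max logit before exp).
float maxLogit = output[0];
for (int i = 1; i < outputSize; i++)
{
    maxLogit = std::max(maxLogit, output[i]);
}
float sum{0.0f};
for (int i = 0; i < outputSize; i++)
{
    output[i] = exp(output[i] - maxLogit); // same softmax result, no overflow
    sum += output[i];
}
for (int i = 0; i < outputSize; i++)
{
    output[i] /= sum;
}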