一个简单的TensorRT MNIST推理案例,网络结构通过TensorRT API在代码中逐层构建。

TensorRT是NVIDIA的一个深度神经网络推理引擎,可以对深度学习模型进行优化和部署。本程序从文件中加载已经训练好的模型权重,用TensorRT API构建网络并生成引擎,再用该引擎进行推理。

TRTLogger是一个日志记录类,继承自nvinfer1::ILogger,用于输出TensorRT的运行日志(只输出INFO及更严重级别的信息)。

Matrix是一个矩阵结构体,用于存储模型权重和网络的输入数据。Model是一个模型结构体,用于存储加载的全部权重矩阵。
print_image函数用于将图像的像素值打印出来,方便调试和查看。load_file函数用于以二进制方式读取整个文件,用来加载模型权重、输入图像以及序列化后的引擎文件。load_model函数用于加载模型权重,其中模型权重的文件名按照"[index].weight"的格式命名,index从0开始递增。模型权重的形状是预先定义好的,存储在weight_shapes数组中,其中weight_shapes[i][0]表示第i个权重矩阵的行数,weight_shapes[i][1]表示第i个权重矩阵的列数(第0、2个为全连接层的权重,第1、3个为对应的偏置)。

这些辅助函数都是为了方便程序的编写和调试,可以根据具体的应用场景进行修改和扩展。
此外,程序还包括bmp_data_to_normalize_matrix函数,用于将BMP格式的图像数据转换为适合输入神经网络的归一化矩阵;以及model_weights_to_trt_weights函数,用于将神经网络的权重转换为TensorRT可以使用的nvinfer1::Weights格式。

do_trt_build_engine函数使用TensorRT API构建神经网络并生成引擎,然后将引擎序列化后保存到文件中。

do_trt_inference函数从文件中加载序列化的引擎,然后使用引擎在一组输入图像上执行推理。对于每个输入图像,它将BMP数据转换为矩阵,将矩阵复制到GPU,使用引擎进行推理,然后将输出概率值复制回CPU以供显示。
main函数首先调用load_model函数加载训练好的模型权重,并打印出每个权重矩阵的形状。

接下来,它调用do_trt_build_engine函数将模型转换为TensorRT引擎,并将引擎保存到文件mnist.trtmodel中。

最后,它调用do_trt_inference函数对一组输入图像执行推理,并显示每个图像的预测结果和置信度。

在推理完成后,它打印出一条消息表示程序运行完成,并返回0表示程序正常退出。

// tensorRT include
#include <NvInfer.h>
#include <NvInferRuntime.h>

// cuda include
#include <cuda_runtime.h>

// system include
#include <stdio.h>
#include <string.h>
#include <math.h>

#include <vector>
#include <string>
#include <fstream>
#include <algorithm>

using namespace std;


// Minimal printf-style logger: prefixes each message with "[file:line]type: "
// and terminates it with a newline. The do{...}while(0) wrapper makes the
// expansion a single statement, so it is safe inside an unbraced if/else.
// Every line of the macro body must end in a backslash (and must not carry a
// trailing // comment, which would swallow the continuation).
#define SIMLOG(type, ...)                                    \
    do{                                                      \
        printf("[%s:%d]%s: ", __FILE__, __LINE__, type);     \
        printf(__VA_ARGS__);                                 \
        printf("\n");                                        \
    }while(0)

// Informational log line; arguments are a printf format string plus values.
#define INFO(...)   SIMLOG("info", __VA_ARGS__)

// Translate a TensorRT logger severity into the short lowercase tag printed
// in front of each log line. Values outside the five known severities map to
// the tag "unknow".
inline const char* severity_string(nvinfer1::ILogger::Severity t){
    using Sev = nvinfer1::ILogger::Severity;

    const char* label = "unknow";
    switch(t){
        case Sev::kINTERNAL_ERROR: label = "internal_error"; break;
        case Sev::kERROR:          label = "error";          break;
        case Sev::kWARNING:        label = "warning";        break;
        case Sev::kINFO:           label = "info";           break;
        case Sev::kVERBOSE:        label = "verbose";        break;
        default:                                             break;
    }
    return label;
}

// ILogger implementation that forwards TensorRT messages to SIMLOG.
// Only messages whose severity compares <= kINFO are printed; anything that
// compares greater (kVERBOSE in TensorRT's ordering) is dropped.
class TRTLogger : public nvinfer1::ILogger{
public:
    virtual void log(Severity severity, nvinfer1::AsciiChar const* msg) noexcept override{
        // Guard clause: skip messages chattier than kINFO.
        if(severity > Severity::kINFO)
            return;

        SIMLOG(severity_string(severity), "%s", msg);
    }
};

// Dense float matrix with contiguous storage, used for the model weights and
// the network input. Callers in this program fill it row by row
// (rows x cols, row-major) — presumably that is the intended layout.
struct Matrix{
    std::vector<float> data;   // exactly rows * cols elements after resize()
    int rows = 0, cols = 0;

    // Set the shape and size the storage to rows * cols ELEMENTS.
    // (Multiplying by sizeof(float) here would count bytes and over-allocate 4x.)
    void resize(int rows, int cols){
        this->rows = rows;
        this->cols = cols;
        this->data.resize(static_cast<size_t>(rows) * static_cast<size_t>(cols));
    }

    // True when no storage has been allocated yet.
    bool empty() const{ return data.empty(); }

    // Logical element count (rows * cols).
    int size() const{ return rows * cols; }

    // Mutable pointer to the first element, available even on a const Matrix
    // because callers (memcpy targets, nvinfer1::Weights) need a float*.
    float* ptr() const{ return const_cast<float*>(data.data()); }
};

// All parameter tensors of the network; load_model fills weights[i] from the
// file "<i>.weight".
struct Model{
    std::vector<Matrix> weights;
};

// Print a rows x cols view of a buffer that stores 3 bytes per pixel, showing
// only byte 0 of each pixel. Buffer rows are printed last-to-first
// (presumably because BMP pixel arrays are stored bottom-up), so the output
// appears upright.
//   a      : pixel bytes; must hold at least rows * cols * 3 bytes.
//   format : printf conversion for one value (receives the byte promoted to
//            int); a comma is appended after each value.
void print_image(const std::vector<unsigned char>& a, int rows, int cols, const char* format = "%3d"){
    INFO("Matrix[%p], %d x %d", static_cast<const void*>(&a), rows, cols);

    // Refuse to read past the end of a too-small buffer.
    const size_t pixels = (rows > 0 && cols > 0)
                        ? static_cast<size_t>(rows) * static_cast<size_t>(cols)
                        : 0;
    if(a.size() < pixels * 3){
        INFO("print_image: buffer holds %zu bytes but %d x %d x 3 = %zu are needed",
             a.size(), rows, cols, pixels * 3);
        return;
    }

    // Build "<format>," dynamically; a fixed char[20] + sprintf overflows for
    // formats longer than 18 characters.
    const std::string fmt = std::string(format) + ",";

    for(int i = 0; i < rows; ++i){
        printf("row[%02d]: ", i);
        const unsigned char* row = a.data() + static_cast<size_t>(rows - i - 1) * cols * 3;
        for(int j = 0; j < cols; ++j)
            printf(fmt.c_str(), row[j * 3]);   // byte 0 of each 3-byte pixel
        printf("\n");
    }
}

// Read an entire file into memory as raw bytes.
// Returns an empty vector if the file cannot be opened, its size cannot be
// determined, it is empty, or fewer bytes than expected could be read.
std::vector<unsigned char> load_file(const std::string& file){
    std::ifstream in(file, std::ios::in | std::ios::binary);
    if(!in.is_open())
        return {};

    in.seekg(0, std::ios::end);
    // tellg() reports -1 on failure; converting that straight to size_t would
    // request an enormous allocation, so check it as a signed offset first.
    const std::streamoff length = in.tellg();
    if(length <= 0)
        return {};

    std::vector<unsigned char> data(static_cast<size_t>(length));
    in.seekg(0, std::ios::beg);
    in.read(reinterpret_cast<char*>(data.data()), length);

    // A short read would leave the tail zero-filled; report failure instead.
    if(in.gcount() != length)
        return {};
    return data;   // the stream is closed by its destructor
}

// Load the network parameters from "0.weight" .. "3.weight" in the working
// directory. Each file must contain exactly rows*cols raw float values for
// its expected shape:
//   0: 1024 x 784   1: 1024 x 1   2: 10 x 1024   3: 10 x 1
// (entries 0/2 are used as fully-connected weights, 1/3 as their biases).
// Logs and returns false on the first missing or wrongly sized file.
bool load_model(Model& model){
    const int weight_shapes[][2] = {{1024, 784}, {1024, 1}, {10, 1024}, {10, 1}};
    const size_t num_weights = sizeof(weight_shapes) / sizeof(weight_shapes[0]);
    model.weights.resize(num_weights);

    for(size_t idx = 0; idx < num_weights; ++idx){
        char file_name[100];
        sprintf(file_name, "%d.weight", static_cast<int>(idx));

        const std::vector<unsigned char> bytes = load_file(file_name);
        if(bytes.empty()){
            INFO("Load %s failed.", file_name);
            return false;
        }

        const int rows = weight_shapes[idx][0];
        const int cols = weight_shapes[idx][1];
        const size_t expected_bytes = static_cast<size_t>(rows * cols) * sizeof(float);
        if(bytes.size() != expected_bytes){
            INFO("Invalid weight file: %s", file_name);
            return false;
        }

        Matrix& target = model.weights[idx];
        target.resize(rows, cols);
        memcpy(target.ptr(), bytes.data(), bytes.size());
    }
    return true;
}

// Convert a 28 x 28 pixel buffer with 3 bytes per pixel (stored last row
// first) into a 1 x 784 network input, first row first. Only byte 0 of each
// pixel is used, normalized as (v / 255 - 0.1307) / 0.3081.
// Logs and returns an empty Matrix if the buffer is not exactly 28*28*3 bytes.
Matrix bmp_data_to_normalize_matrix(const std::vector<unsigned char>& data){
    constexpr int kWidth  = 28;
    constexpr int kHeight = 28;
    constexpr float kMean = 0.1307f;
    constexpr float kStd  = 0.3081f;

    Matrix result;
    if(data.size() != kWidth * kHeight * 3){
        INFO("Invalid bmp file, must be %d x %d @ rgb 3 channels image", kWidth, kHeight);
        return result;
    }
    result.resize(1, kWidth * kHeight);

    float* dst = result.ptr();
    for(int out_row = 0; out_row < kHeight; ++out_row){
        // Output row r is taken from source row (H - 1 - r).
        const unsigned char* src = data.data() + (kHeight - out_row - 1) * kWidth * 3;
        for(int col = 0; col < kWidth; ++col, ++dst)
            *dst = (src[col * 3 + 0] / 255.0f - kMean) / kStd;
    }
    return result;
}

// Wrap a Matrix as an FP32 nvinfer1::Weights descriptor. No data is copied:
// `values` points into model_weights.data, so the Matrix must stay alive
// while TensorRT uses the weights (TensorRT requires weight memory to remain
// valid until the engine has been built).
nvinfer1::Weights model_weights_to_trt_weights(const Matrix& model_weights){
    nvinfer1::Weights trt_weights{};
    trt_weights.count  = model_weights.size();
    trt_weights.values = model_weights.ptr();
    trt_weights.type   = nvinfer1::DataType::kFLOAT;
    return trt_weights;
}

// Single process-wide TensorRT logger; being global, it outlives the builder
// created in do_trt_build_engine and the runtime created in do_trt_inference.
TRTLogger logger{};
// Build a TensorRT engine for the two-layer MLP held in `model` and write the
// serialized engine to `save_file`.
//
//   image [1,784,1,1] -> FC(784 -> 1024, bias) -> ReLU
//                     -> FC(1024 -> 10, bias)  -> Sigmoid -> prob
//
// model.weights must be {fc1 weight, fc1 bias, fc2 weight, fc2 bias} as filled
// by load_model. The weight buffers are referenced, not copied, so `model`
// must stay alive for the whole call. Every failure is logged through INFO,
// and all TensorRT objects created so far are destroyed before returning.
void do_trt_build_engine(const Model& model, const std::string& save_file){
    if(model.weights.size() < 4){
        INFO("Build engine failed: expected 4 weight tensors, got %d", static_cast<int>(model.weights.size()));
        return;
    }
    const Matrix& fc1_w = model.weights[0];
    const Matrix& fc1_b = model.weights[1];
    const Matrix& fc2_w = model.weights[2];
    const Matrix& fc2_b = model.weights[3];
    // The input tensor below is fixed at 784 features; biases must match the layer widths.
    if(fc1_w.cols != 784 || fc1_b.size() != fc1_w.rows ||
       fc2_w.cols != fc1_w.rows || fc2_b.size() != fc2_w.rows){
        INFO("Build engine failed: weight shapes are inconsistent.");
        return;
    }

    nvinfer1::IBuilder*           builder    = nullptr;
    nvinfer1::IBuilderConfig*     config     = nullptr;
    nvinfer1::INetworkDefinition* network    = nullptr;
    nvinfer1::ICudaEngine*        engine     = nullptr;
    nvinfer1::IHostMemory*        model_data = nullptr;

    // Destroy whatever exists, newest first; used on every exit path so
    // nothing leaks when a step fails part-way.
    auto release = [&](){
        if(model_data) model_data->destroy();
        if(engine)     engine->destroy();
        if(network)    network->destroy();
        if(config)     config->destroy();
        if(builder)    builder->destroy();
    };

    builder = nvinfer1::createInferBuilder(logger);
    if(builder == nullptr){
        INFO("Create builder failed.");
        return;
    }
    config = builder->createBuilderConfig();
    // Flag value 1 sets bit 0 (kEXPLICIT_BATCH in TensorRT 7/8): the batch
    // dimension is part of the input dims.
    network = builder->createNetworkV2(1);
    if(config == nullptr || network == nullptr){
        INFO("Create builder config / network failed.");
        release();
        return;
    }

    nvinfer1::ITensor* input = network->addInput("image", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4(1, 784, 1, 1));
    nvinfer1::Weights layer1_weight = model_weights_to_trt_weights(fc1_w);
    nvinfer1::Weights layer1_bias   = model_weights_to_trt_weights(fc1_b);
    auto layer1 = input  ? network->addFullyConnected(*input, fc1_w.rows, layer1_weight, layer1_bias) : nullptr;
    auto relu1  = layer1 ? network->addActivation(*layer1->getOutput(0), nvinfer1::ActivationType::kRELU) : nullptr;

    nvinfer1::Weights layer2_weight = model_weights_to_trt_weights(fc2_w);
    nvinfer1::Weights layer2_bias   = model_weights_to_trt_weights(fc2_b);
    auto layer2 = relu1  ? network->addFullyConnected(*relu1->getOutput(0), fc2_w.rows, layer2_weight, layer2_bias) : nullptr;
    auto prob   = layer2 ? network->addActivation(*layer2->getOutput(0), nvinfer1::ActivationType::kSIGMOID) : nullptr;
    if(prob == nullptr){
        INFO("Define network failed.");
        release();
        return;
    }

    network->markOutput(*prob->getOutput(0));
    config->setMaxWorkspaceSize(1 << 28);   // 256 MiB of builder scratch space
    // Has no effect on explicit-batch networks; kept for older TensorRT behaviour.
    builder->setMaxBatchSize(1);

    engine = builder->buildEngineWithConfig(*network, *config);
    if(engine == nullptr){
        INFO("Build engine failed.");
        release();
        return;
    }

    model_data = engine->serialize();
    if(model_data == nullptr){
        INFO("Serialize engine failed.");
        release();
        return;
    }

    std::ofstream outf(save_file, std::ios::binary | std::ios::out);
    if(!outf.is_open()){
        INFO("Open %s failed", save_file.c_str());
    }else if(!outf.write(static_cast<const char*>(model_data->data()), model_data->size())){
        INFO("Write %s failed", save_file.c_str());
    }

    release();
}

// Deserialize the engine stored in `model_file` and classify each image of a
// fixed list ("5.bmp", "6.bmp"). Each image file must be a 54-byte header
// followed by a 28 x 28 x 3 pixel array; other files are skipped with a log
// message. After each prediction the user is prompted: Enter continues,
// 'q' quits.
//
// Expects an engine with exactly two bindings, the input named "image" and
// one output of 10 floats (as produced by do_trt_build_engine).
void do_trt_inference(const std::string& model_file){
    auto engine_data = load_file(model_file);
    if(engine_data.empty()){
        INFO("engine_data is empty");
        return;
    }

    const int num_classes        = 10;
    const int input_numel        = 28 * 28;
    const int bmp_file_head_size = 54;

    nvinfer1::IRuntime*          runtime           = nullptr;
    nvinfer1::ICudaEngine*       engine            = nullptr;
    nvinfer1::IExecutionContext* execution_context = nullptr;
    cudaStream_t                 stream            = nullptr;
    float*                       image_device_ptr  = nullptr;
    float*                       output_device_ptr = nullptr;

    // Tear down everything created so far, newest first; used on every exit path.
    auto release = [&](){
        if(image_device_ptr)  cudaFree(image_device_ptr);
        if(output_device_ptr) cudaFree(output_device_ptr);
        if(stream)            cudaStreamDestroy(stream);
        if(execution_context) execution_context->destroy();
        if(engine)            engine->destroy();
        if(runtime)           runtime->destroy();
    };
    // Log a failing CUDA runtime call; returns true on success.
    auto cuda_ok = [](cudaError_t err, const char* what) -> bool{
        if(err == cudaSuccess)
            return true;
        INFO("%s failed: %s", what, cudaGetErrorString(err));
        return false;
    };

    runtime = nvinfer1::createInferRuntime(logger);
    if(runtime == nullptr){
        INFO("Create runtime failed.");
        return;
    }
    engine = runtime->deserializeCudaEngine(engine_data.data(), engine_data.size());
    if(engine == nullptr){
        INFO("Deserialize cuda engine failed.");
        release();
        return;
    }

    // Resolve the input slot by name instead of assuming it is binding 0.
    const int input_index = engine->getBindingIndex("image");
    if(engine->getNbBindings() != 2 || input_index < 0){
        INFO("Unexpected engine bindings: need input 'image' plus one output.");
        release();
        return;
    }
    const int output_index = 1 - input_index;

    execution_context = engine->createExecutionContext();
    if(execution_context == nullptr){
        INFO("Create execution context failed.");
        release();
        return;
    }

    // Allocate the stream and device buffers once and reuse them for every image.
    if(!cuda_ok(cudaStreamCreate(&stream), "cudaStreamCreate") ||
       !cuda_ok(cudaMalloc(&image_device_ptr, input_numel * sizeof(float)), "cudaMalloc(image)") ||
       !cuda_ok(cudaMalloc(&output_device_ptr, num_classes * sizeof(float)), "cudaMalloc(output)")){
        release();
        return;
    }

    void* bindings[2];
    bindings[input_index]  = image_device_ptr;
    bindings[output_index] = output_device_ptr;

    const char* image_list[] = {"5.bmp", "6.bmp"};
    for(const char* file_name : image_list){
        auto image_data = load_file(file_name);
        if(image_data.size() != static_cast<size_t>(bmp_file_head_size + input_numel * 3)){
            INFO("Load image failed: %s", file_name);
            continue;
        }

        // Drop the fixed-size header, keeping only the pixel array.
        image_data.erase(image_data.begin(), image_data.begin() + bmp_file_head_size);
        auto image = bmp_data_to_normalize_matrix(image_data);
        if(image.empty())
            continue;

        float predict_proba[num_classes] = {};
        if(!cuda_ok(cudaMemcpyAsync(image_device_ptr, image.ptr(), input_numel * sizeof(float),
                                    cudaMemcpyHostToDevice, stream), "cudaMemcpyAsync(image)"))
            break;
        if(!execution_context->enqueueV2(bindings, stream, nullptr)){
            // Never read the output buffer of a failed enqueue.
            INFO("Inference failed: %s", file_name);
            cudaStreamSynchronize(stream);
            continue;
        }
        if(!cuda_ok(cudaMemcpyAsync(predict_proba, output_device_ptr, sizeof(predict_proba),
                                    cudaMemcpyDeviceToHost, stream), "cudaMemcpyAsync(prob)") ||
           !cuda_ok(cudaStreamSynchronize(stream), "cudaStreamSynchronize"))
            break;

        const int predict_label  = static_cast<int>(std::max_element(predict_proba, predict_proba + num_classes) - predict_proba);
        const float predict_prob = predict_proba[predict_label];

        print_image(image_data, 28, 28);
        INFO("image matrix: %d x %d", image.rows, image.cols);
        INFO("%s predict: %d, confidence: %f", file_name, predict_label, predict_prob);

        printf("Press 'Enter' to next, Press 'q' to quit: ");
        fflush(stdout);
        int c = getchar();
        if(c == 'q')
            break;
        // Consume the rest of the line so leftover characters don't answer the next prompt.
        while(c != '\n' && c != EOF)
            c = getchar();
    }

    INFO("Clean memory");
    release();
}

// Entry point: load the weight files, build and serialize the TensorRT engine,
// then run interactive inference on the sample images.
// Exits with status 1 if the weight files cannot be loaded, otherwise 0.
int main(){
    Model model;
    if(!load_model(model))
        return 1;   // a failed load is an error, not a normal exit

    for(size_t i = 0; i < model.weights.size(); ++i){
        INFO("weight.%d shape = %d x %d", static_cast<int>(i), model.weights[i].rows, model.weights[i].cols);
    }

    const char* trtmodel = "mnist.trtmodel";
    do_trt_build_engine(model, trtmodel);
    do_trt_inference(trtmodel);
    INFO("done.");
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_44089890/article/details/130003098