[Notas de estudio de CV] análisis de código línea por línea tensorrtx-yolov5

1. Introducción

TensorRTx (en adelante, trtx) es una biblioteca de código abierto muy popular que usa API para construir estructuras de red para lograr la aceleración trt. El autor mencionó por qué no usar el analizador ONNX para la aceleración trt, pero usar la API de nivel más bajo para construir la aceleración trt. Las razones son las siguientes:

  • Flexible facilita la modificación de cualquier capa del modelo, como eliminar, agregar, reemplazar, etc.
  • Depurable puede obtener fácilmente los resultados de una determinada capa en el medio del modelo.
  • La oportunidad de aprender puede proporcionar una mayor comprensión de la estructura del modelo.

Aunque el método onnx2trt actualmente no causa problemas en la mayoría de los casos, bajo trtx podemos dominar los principios y códigos de nivel inferior, lo que es beneficioso para nuestra implementación y optimización del modelo. A continuación se utilizará el ejemplo de yolov5s en el marco trtx para analizar cómo funciona trtx línea por línea.

Enlace del proyecto TensorRTx: https://github.com/wang-xinyu/tensorrtx.

2. Análisis de pasos

En trtx, el proceso de aceleración de un modelo se puede dividir en dos pasos:

  • Extraer los parámetros del modelo de pytorch wts
  • Utilice la API subyacente trt para construir la estructura de la red y completar los parámetros en wts en la red.
2.1、get_wts.py

Primero, debe extraer los parámetros del modelo en pytorch. Los parámetros del modelo en pytorch existen en el formato de blob en caffe. Cada operación tiene un nombre, longitud de datos y datos correspondientes.

for k, v in model.state_dict().items():
    # k-> blob的名字
    vr = v.reshape(-1).cpu().numpy() # vr -> 数据长度
    f.write('{} {} '.format(k, len(vr)))
    for vv in vr:
        f.write(' ')
        f.write(struct.pack('>f', float(vv)).hex()) # 将数据转化到16进制
        f.write('\n')

Al cargar get_wts.py, puede obtener los parámetros del modelo, incluido yolov5s.pth. Abra yolov5s.wts como se muestra a continuación:

Insertar descripción de la imagen aquí

El 351 en la primera línea es el número total de blobs, el model.0.conv.weight en la segunda línea es el nombre del primer blob, 3456 representa la longitud de los datos del blob y 3a198000 3ca58000... son los parámetros reales.

Después de obtener los parámetros anteriores, puede acelerar en modo trtx.

2.2 Construir motor

Antes de usar wts para convertir a motor, debe tener muy clara la estructura de red del modelo. Los estudiantes que no estén seguros pueden consultar el diagrama de estructura de red de yolov5 de Little Mung Bean de Girasol . Después de comprender la estructura de red de yolov5, puede comenzar a usar la API de trt para crear un modelo de red. El código para construir el modelo se encuentra en la función build_det_engine en model.cpp. Este artículo dibuja el proceso de código directamente en el diagrama de estructura de red de yolov5. Puede verificar directamente el código y el diagrama.
Insertar descripción de la imagen aquí

//yolov5_det.cpp
viod serialize_engine(...){
    
    
	if (is_p6) {
    
    
        ...
	} else {
    
    
        // 以yolov5s为例
        engine = build_det_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name);
  	}
    // 序列化
    IHostMemory* serialized_engine = engine->serialize();
    std::ofstream p(engine_name, std::ios::binary);
    // 写到文件中
    p.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());

}

modelo.cpp

// 解析get_wts.py
static std::map<std::string, Weights> loadWeights(const std::string file) {
    
    
    int32_t count;  // wts文件第一行,共有351个blob
  	input >> count;
    //每一行是一个blob,模型名称 + 数据长度 + 参数
    while (count--) {
    
    
        // 一个blob的参数
     	Weights wt{
    
     DataType::kFLOAT, nullptr, 0 };
        uint32_t size;  //blob 数据长度
        std::string name; // blob 数据名字
        for (uint32_t x = 0, y = size; x < y; ++x) {
    
    
      		input >> std::hex >> val[x];  // 将数据转化成十进制,并放到val中
    	}
        // 每个blob名字对应一个wt
        weightMap[name] = wt;
    }
}


ICudaEngine* build_det_engine(){
    
    
   // 初始化网络结构
   INetworkDefinition* network = builder->createNetworkV2(0U);
   // 定义模型输入
   ITensor* data = network->addInput(kInputTensorName, dt, Dims3{
    
     3, kInputH, kInputW });
   // 加载pytorch模型中的参数
   std::map<std::string, Weights> weightMap = loadWeights(wts_name);
    
   // 逐步添加网络结构,已将代码与网络结构一一对应 ,具体过程见上图
 
   // 增加yolo后处理decode模块,使用了plugin
   auto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{
    
    det0, det1, det2});
   network->markOutput(*yolo->getOutput(0));  //将plugin的输出设置为模型的最后输出(decode)
    
   #if defined(USE_FP16)
  	// FP16
	config->setFlag(BuilderFlag::kFP16);
   #elif defined(USE_INT8)
    // INT8 量化
    std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;
    assert(builder->platformHasFastInt8());
    config->setFlag(BuilderFlag::kINT8);
    Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName);
      config->setInt8Calibrator(calibrator);
    #endif
    // 根据网络结构来生成engine
    ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
	return engine;
}
3、complemento

También estoy aprendiendo sobre el complemento. La siguiente es mi comprensión básica del complemento mientras aprendo el código trtx-yolo5. El autor original agregó un complemento de decodificación de modelo detrás del modelo para obtener el bbox en cada capa de entidades. El código de llamada está en model.cpp.

auto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{
    
    det0, det1, det2});

static IPluginV2Layer* addYoLoLayer(...){
    
    
    // 注册一个名为 "YoloLayer_TRT"的插件,如果找不到插件,就会报错
    auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");
    
    // plugin的数据
    PluginField plugin_fields[2];
    int netinfo[5] = {
    
    kNumClass, kInputW, kInputH, kMaxNumOutputBbox, (int)is_segmentation};  //维度数据
  	plugin_fields[0].data = netinfo;  
  	plugin_fields[0].length = 5; 
  	plugin_fields[0].name = "netinfo";
  	plugin_fields[0].type = PluginFieldType::kFLOAT32;
    
    // 所有plugin的参数
    PluginFieldCollection plugin_data;
  	plugin_data.nbFields = 2;
  	plugin_data.fields = plugin_fields;
    // 创建plugin的对象 
    IPluginV2 *plugin_obj = creator->createPlugin("yololayer", &plugin_data);
}

El código de implementación está en yololayer.h/cu

class API YoloLayerPlugin : public IPluginV2IOExt {
    	
    // 设置插件名称,在注册插件时会寻找对应的插件
      const char* getPluginType() const TRT_NOEXCEPT override{
          return "YoloLayer_TRT";
      }

    
    //插件构造函数
	YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector<YoloKernel>& vYoloKernel){
      /*
      	classCount:类别数量
      	netWidth:输入宽
      	netHeight:输入高
      	maxOut:最大检测数量
      	is_segmentation:是否含有实例分割
      	vYoloKernel:anchors参数
      */
    }
    
}

// 插件运行时调用的代码
void YoloLayerPlugin::forwardGpu(...){
    // 输出结果 1+ 是在第一个位置记录解码的数量
    int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);
    
    // 将存放结果的内存置为0
    for (int idx = 0; idx < batchSize; ++idx) {
    	CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream));
 
    // 遍历三种不同尺度的anchor
    for (unsigned int i = 0; i < mYoloKernel.size(); ++i) {
        // 调用核函数进行解码
     	CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> >(...)
    }
    
}

__global__ void CalDetection(...){
    // input:模型输出结果
    // output:decode存放地址
    // 当前线程的的全局索引ID
    int idx = threadIdx.x + blockDim.x * blockIdx.x;
    // yoloWidth * yoloHeight
    int total_grid = yoloWidth * yoloHeight; // 在当前特征层上要处理的总框数
    int bnIdx = idx / total_grid;    // 第n个batch    
    // x,y,w,h,score + 80
    int info_len_i = 5 + classes;
    // 如果带有实例分割分析,需要再加上32个分割系数
    if (is_segmentation) info_len_i += 32;
    
    // 第n个batch的推理结果开始地址
    const float* curInput = input + bnIdx * (info_len_i * total_grid * kNumAnchor);
    // 遍历三种不同尺寸的anchor
    for (int k = 0; k < kNumAnchor; ++k) {
        //每个框的置信度
    	float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);
        if (box_prob < kIgnoreThresh) continue;
        for (int i = 5; i < 5 + classes; ++i) {
            // 每个类别的概率
        	float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);
            // 提取最大概率以及类别ID
            if (p > max_cls_prob) {
        		max_cls_prob = p;
        		class_id = i - 5;
      		}
        }
        // 
        float *res_count = output + bnIdx * outputElem;
        // 统计decode框的数量	
        int count = (int)atomicAdd(res_count, 1);
		// 下面是按照论文的公式将预测的宽和高恢复到原图大小
		...
    }
}
4. Resumen

A través de este estudio en profundidad del código fuente abierto de trtx, aprendí cómo usar la API de trt para acelerar el modelo y también aprendí sobre la implementación del complemento. Continuaré aprendiendo los puntos de conocimiento de trtx en el futuro.

Supongo que te gusta

Origin blog.csdn.net/weixin_42108183/article/details/133212765
Recomendado
Clasificación