CenterFaceモデルからTensorRTへ

CenterFaceモデルからTensorRTへ

1.githubオープンソースコード

CenterFaceTensorRT推論のオープンソースコードの場所はhttps://github.com/linghu8812/tensorrt_inference/tree/master/CenterFaceにあり、作成者のオープンソースコードの場所はhttps://github.com/Star-Clouds/CenterFaceにあります。論文のarxivアドレスはhttps://arxiv.org/abs/1911.03599です。

2.ONNXモデルを書き直します

著者は2つのオープンソースモデルをgithubに配置しました。TensorRTを変換すると、これら2つのモデルがPyTorchを介して変換されたことがログからわかります。著者は、PyTorchを介してモデルを構築するためのコードをオープンソース化しておらず、オープンソースもしていません。モデルのPyTorchバージョン。netron、あなたは著者のオープンソースONNXモデルの解像度があることがわかります32×32 32 \回323 2××3 2入力TensorRT推論エンジンが固定されているので、入力と出力サイズONNXモデルを修正する必要があります。下の図に示すように、入力と出力の解像度は640×640 640 \ times640に変更されます。6 4 0××6 4 0および160×160 160 \ 160倍1 6 0××1 6 0

python3 export_onnx.py

図1入力サイズ
図2出力サイズ

3.ONNXモデルをTensorRTモデルに変換します

3.1概要

TensorRTモデルはTensorRTの推論エンジンであり、コードはC ++で実装されています。関連する構成はconfig.yamlファイルに書き込まれ、engine_fileパスが存在する場合は読み取られengine_file、そうでない場合はonnx_file生成されengine_fileます。

void CenterFace::LoadEngine() {
    // create and load engine
    std::fstream existEngine;
    existEngine.open(engine_file, std::ios::in);
    if (existEngine) {
        readTrtFile(engine_file, engine);
        assert(engine != nullptr);
    } else {
        onnxToTRTModel(onnx_file, engine_file, engine, BATCH_SIZE);
        assert(engine != nullptr);
    }
}

config.yamlファイルは、推論のバッチサイズ、画像のサイズ、検出のしきい値を設定するだけで済みます。これは、回帰にアンカーを必要とするモデルと比較した場合のアンカーフリーモデルの利点です。

CenterFace:
  onnx_file:     "../centerface.onnx"
  engine_file:   "../centerface.trt"
  BATCH_SIZE:    1
  INPUT_CHANNEL: 3
  IMAGE_WIDTH:   640
  IMAGE_HEIGHT:  640
  obj_threshold: 0.5
  nms_threshold: 0.45

画像の元のデータをテンソルに変換する場合、画像のアスペクト比を維持する必要があります。対応するコードは次のとおりです。

float ratio = float(IMAGE_WIDTH) / float(src_img.cols) < float(IMAGE_HEIGHT) / float(src_img.rows) ? float(IMAGE_WIDTH) / float(src_img.cols) : float(IMAGE_HEIGHT) / float(src_img.rows);
cv::Mat flt_img = cv::Mat::zeros(cv::Size(IMAGE_WIDTH, IMAGE_HEIGHT), CV_8UC3);
cv::Mat rsz_img;
cv::resize(src_img, rsz_img, cv::Size(), ratio, ratio);
rsz_img.copyTo(flt_img(cv::Rect(0, 0, rsz_img.cols, rsz_img.rows)));
flt_img.convertTo(flt_img, CV_32FC3);

3.2コンパイル

次のコマンドを使用してプロジェクトをコンパイルし、生成しますyolov5_trt

mkdir build && cd build
cmake ..
make -j

3.3操作

次のコマンドでプロジェクトを実行して、推論結果を取得します

./CenterFace_trt ../config.yaml ../samples

4.推論結果

推論結果を次の図に示します。
推論結果

おすすめ

転載: blog.csdn.net/linghu8812/article/details/109549702