TensorFlow Serving使用指南

简介

TensorFlow Serving 是一个适用于机器学习模型的灵活、高性能应用系统,专为生产环境而设计。借助 TensorFlow Serving,您可以轻松部署新算法和实验,同时保留相同的服务器架构和 API。TensorFlow Serving 提供与 TensorFlow 模型的开箱即用型集成,但也可以轻松扩展以应用其他类型的模型和数据。

服务示例流程

  • 安装了模型服务器的最新 TensorFlow 服务图像
docker pull tensorflow/serving
  • 使用一个名为 的玩具模型,该模型为我们提供的预测值生成。Half Plus Two``0.5 * x + 2``x要获取此模型,请先克隆 TensorFlow Serve 存储库。
mkdir -p /tmp/tfserving
cd /tmp/tfserving
git clone https://github.com/tensorflow/serving
  • 运行 TensorFlow Serve 容器,将其指向此模型并打开 REST API 端口 (8501)
 docker run -p 8501:8501 -v /tmp/tfserving/serving/tensorflow_serving/servables/tensorflow/testdata/saved_model_half_plus_two_cpu:/models/half_plus_two   -e MODEL_NAME=half_plus_two -t tensorflow/serving:latest-gpu &
  • docker run 是在 Docker 中运行容器的命令。
  • -p 8501:8501 是将容器内部的端口 8501 映射到主机上的端口 8501。这样可以在外部访问 TensorFlow Serving 模型服务器。
  • -v /tmp/tfserving/serving/tensorflow_serving/servables/tensorflow/testdata/saved_model_half_plus_two_cpu:/models/half_plus_two 挂载了一个本地目录作为容器内部的模型目录。这样可以在容器中加载模型。
  • -e MODEL_NAME=half_plus_two 设置了环境变量 MODEL_NAME 为 “half_plus_two”。TensorFlow Serving 模型服务器会使用这个环境变量来确定加载哪个模型。
  • -t tensorflow/serving:latest-gpu 指定了使用的镜像为 TensorFlow Serving 的最新版本。最后的 & 意思是在后台运行,可以继续使用命令行。
  • 使用预测 API 查询模型,通过命令行工具 curl 向 TensorFlow Serving 模型服务器发送了一个预测请求,要求预测模型 “half_plus_two” 对于输入数据 [1.0, 2.0, 5.0] 的输出
curl -d '{"instances": [1.0, 2.0, 5.0]}' \
  -X POST http://localhost:8501/v1/models/half_plus_two:predict
  • 返回
{ "predictions": [2.5, 3.0, 4.5] }

模型训练、加载、推理

  • 训练模型(删除run_in_docker.sh中自动更新docker的部分)
tools/run_in_docker.sh python tensorflow_serving/example/mnist_saved_model.py \
  /tmp/mnist
  • 模型保存在
$ ls /tmp/mnist/1
saved_model.pb variables
  • 加载模型
docker run -p 8500:8500 \
--mount type=bind,source=/tmp/mnist,target=/models/mnist \
-e MODEL_NAME=mnist -t tensorflow/serving &
  • 测试模型
tools/run_in_docker.sh python tensorflow_serving/example/mnist_client.py  --num_tests=1000 --server=127.0.0.1:8500
  • 输出结果
Inference error rate: 11.13%

查看模型情况

本地浏览器打开如下网址, 可以JSON的形式查看模型的基本运行及 metadata 情况.

查看模型概况

http://127.0.0.1:8501/v1/models/mnist
{
  "model_version_status": [
    {
      "version": "1",
      "state": "AVAILABLE",
      "status": {
        "error_code": "OK",
        "error_message": ""
      }
    }
  ]
}

查看模型metadata 数据情况

http://127.0.0.1:8501/v1/models/mnist/metadata
{
  "model_spec": {
    "name": "mnist",
    "signature_name": "",
    "version": "1"
  },
  "metadata": {
    "signature_def": {
      "signature_def": {
        "serving_default": {
          "inputs": {
            "inputs": {
              "dtype": "DT_FLOAT",
              "tensor_shape": {
                "dim": [
                  {
                    "size": "-1",
                    "name": ""
                  },
                  {
                    "size": "224",
                    "name": ""
                  },
                  {
                    "size": "224",
                    "name": ""
                  },
                  {
                    "size": "3",
                    "name": ""
                  }
                ],
                "unknown_rank": false
              },
              "name": "concat_channel:0"
            }
          },
          "outputs": {
            "outputs": {
              "dtype": "DT_INT64",
              "tensor_shape": {
                "dim": [
                  {
                    "size": "-1",
                    "name": ""
                  }
                ],
                "unknown_rank": false
              },
              "name": "predict:0"
            }
          },
          "method_name": "tensorflow/serving/predict"
        }
      }
    }
  }
}

调用接口进行预测

import cv2
import numpy as np
import  requests
import json
url = 'http://127.0.0.1:8501/v1/models/mnist:predict'
image = cv2.imread("11.jpg", cv2.IMREAD_COLOR)
image = image.astype(np.float32) / 255
image = image.tolist()
headers = {"content-type": "application/json"}
body = {
        "signature_name": "serving_default",
        "inputs": [
           image 
           ]
        }
r = requests.post(url, data = json.dumps(body), headers = headers)
text = r.text
print(text)

猜你喜欢

转载自blog.csdn.net/qq128252/article/details/128784699