tensorflow model server 回归模型保存与调用方法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u011961856/article/details/79686938

安装tensorfow model server:

安装依赖包,

sudo apt-get update && sudo apt-get install -y \
        build-essential \
        curl \
        libcurl3-dev \
        git \
        libfreetype6-dev \
        libpng12-dev \
        libzmq3-dev \
        pkg-config \
        python-dev \
        python-numpy \
        python-pip \
        software-properties-common \
        swig \
        zip \
        zlib1g-dev

安装tensorflow-serving-api,

pip install tensorflow-serving-api

安装server,

sudo apt-get update && sudo apt-get install tensorflow-model-server

将模型保存

设置保存模型路径,模型版本,

# Export inference model.
output_dir='pix2pix_model'
model_version=1
output_path = os.path.join(
    tf.compat.as_bytes(output_dir),
    tf.compat.as_bytes(str(model_version)))
print('Exporting trained model to', output_path)
builder = tf.saved_model.builder.SavedModelBuilder(output_path)

使用tf.saved_model.utils.build_tensor_info,将模型输入,输出转换为server变量形式,并保存

image_size=512
images = tf.placeholder(tf.float32, [None, image_size, image_size,3])#模型输入
model=pix2pix()
# Run inference.
outputs = model.sampler(images)#模型输出
saver = tf.train.Saver()
saver.restore(sess, 'checkpoint-0')#加载已经训练好的模型参数
inputs_tensor_info = tf.saved_model.utils.build_tensor_info(images)
outputs_tensor_info = tf.saved_model.utils.build_tensor_info(outputs)
prediction_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'images': inputs_tensor_info},
        outputs={
            'outputs': outputs_tensor_info,

        },
        method_name=tf.saved_model.signature_constants.REGRESS_METHOD_NAME
    ))
builder.save()#保存

由于我这里使用的时回归模型,因此method_name=tf.saved_model.signature_constants.**REGRESS**_METHOD_NAME

若是分类模型,则改为,

method_name=tf.saved_model.signature_constants.**PREDICT**_METHOD_NAME

保存之后,便可以在对应的路径下得到对应版本的模型文件,例如,本文中,保存路径为pix2pix_model,版本为1,则有,

这里写图片描述

使用方法

按照上述方法保存模型后,便可以启动客户端,命令如下:

tensorflow_model_server –port=9000 –model_name=pix2pix –model_base_path=/home/detection/tensorflow_serving/example/data/pix2pix_model/

注意,model_base_path必须为绝对路径,否则会出错.

客户端调用model:

python pix2pix_client.py –num_tests=1000 –server=localhost:9000

pix2pix_clinet.py定义如下,

def main(_):
  host, port = FLAGS.server.split(':')
  channel = implementations.insecure_channel(host, int(port))
  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
  data = imread(FLAGS.image)
  data = data / 127.5 - 1.
  image_size=512
  sample=[]
  sample.append(data)
  sample_image = np.asarray(sample).astype(np.float32)
  request = predict_pb2.PredictRequest()
  request.model_spec.name = 'pix2pix'
  request.model_spec.signature_name =tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  request.inputs['images'].CopyFrom(
      tf.contrib.util.make_tensor_proto(sample_image, shape=[1, image_size, image_size,3]))
  result_future = stub.Predict.future(request, 5.0)  # 5 seconds
  response = np.array(
    result_future.result().outputs['outputs'].float_val)
  out=(response.reshape((512,512,3))+1)*127.5
  out= cv2.cvtColor(out.astype(np.float32), cv2.COLOR_BGR2RGB)
  cv2.imwrite('data/test_result/' + '1.jpg', out)

完整代码可以参考我的github:https://github.com/qinghua2016/pix2pix_server

c++调用可参考:
https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/inception_client.cc

java调用可以参考:

https://github.com/foxgem/how-to/blob/master/tensorflow/clients/src/main/java/foxgem/Launcher.java

猜你喜欢

转载自blog.csdn.net/u011961856/article/details/79686938