使用 TensorFlow for Python 训练模型、TensorFlow for Java 进行模型预测(导出 SavedModel:同时生成 saved_model.pb 文件和 variables 变量目录)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/wshzd/article/details/88846136

python脚本(此代码为用单隐层神经网络拟合 y = x² 的非线性回归demo)

#!/usr/bin/python
# -*- coding:utf-8 -*-

import tensorflow as tf
from tensorflow import saved_model as sm
import numpy as np

# Synthetic training set: 200 evenly spaced inputs in [-0.5, 0.5], stored as a
# (200, 1) column, with targets y = x^2 plus Gaussian noise (mean 0, std 0.02).
x_data = np.linspace(-0.5, 0.5, 200).reshape(-1, 1)
noise = np.random.normal(loc=0, scale=0.02, size=x_data.shape)
y_data = x_data * x_data + noise

# Graph inputs: float32 batches with one feature per row (batch size is left
# open via None). `x` feeds the network; `y` holds the regression targets.
# NOTE: no explicit name= is passed anywhere in this graph, so every op gets a
# TensorFlow-assigned name; keep the statement order stable if an external
# consumer looks tensors up by name.
x = tf.placeholder(tf.float32,[None,1])
y = tf.placeholder(tf.float32,[None,1])

# Hidden layer: 1 input feature -> 10 units, tanh activation.
# Weights start from a random normal draw, biases start at zero.
Weights_L1 = tf.Variable(tf.random_normal([1,10]))
biases_L1 = tf.Variable(tf.zeros([1,10]))
Wx_plus_b_L1 = tf.matmul(x,Weights_L1) + biases_L1
L1 = tf.nn.tanh(Wx_plus_b_L1)

# Output layer: 10 hidden units -> 1 output, again passed through tanh, so
# predictions are bounded to (-1, 1).
Weights_L2 = tf.Variable(tf.random_normal([10,1]))
biases_L2 = tf.Variable(tf.zeros([1,1]))
Wx_plus_b_L2 = tf.matmul(L1,Weights_L2) + biases_L2
prediction = tf.nn.tanh(Wx_plus_b_L2)

# Mean squared error between the targets and the network output.
loss = tf.reduce_mean(tf.square(y - prediction))
# Train with plain gradient descent (learning rate 0.1).
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# Train the network, print its fitted predictions, and export the session's
# graph and variable values as a SavedModel tagged for serving.
with tf.Session() as sess:
    # Initialize all variables before the first training step.
    sess.run(tf.global_variables_initializer())

    # Full-batch training: every step feeds the whole data set.
    for _ in range(2000):
        sess.run(train_step, feed_dict={x: x_data, y: y_data})

    # Evaluate the trained network on the training inputs.
    prediction_value = sess.run(prediction, feed_dict={x: x_data})
    print(prediction_value)

    # 1. Create the SavedModelBuilder and set the export directory.
    #    The builder refuses to write into a directory that already exists,
    #    so remove a previous export (or pick a new path) before re-running.
    path = 'hdfs://default/home/hdp_ubu_tech_wei/resultdata/wuxian_prs/hometown/hezhidong/LinearModel'
    builder = sm.builder.SavedModelBuilder(path)

    # 2. Describe the serving signature: input placeholder -> network output.
    #    The signature maps a raw float tensor to a raw float tensor under
    #    custom keys, which is the contract of the generic "predict" method.
    #    (The "regress" method expects a serialized tf.Example string under the
    #    key 'inputs' and a float under 'outputs', which this graph does not
    #    provide.)
    X_TensorInfo = sm.utils.build_tensor_info(x)
    y_TensorInfo = sm.utils.build_tensor_info(prediction)
    SignatureDef = sm.signature_def_utils.build_signature_def(
        inputs={'input_1': X_TensorInfo},
        outputs={'output': y_TensorInfo},
        method_name=sm.signature_constants.PREDICT_METHOD_NAME,
    )

    # 3. Write the graph, the signature and the current variable values into
    #    a MetaGraphDef tagged for serving, then flush everything to disk.
    builder.add_meta_graph_and_variables(
        sess,
        tags=[sm.tag_constants.SERVING],
        signature_def_map={'prediction': SignatureDef},
    )
    builder.save()


 

猜你喜欢

转载自blog.csdn.net/wshzd/article/details/88846136