用 TensorFlow for Python 训练模型、用 TensorFlow for Java 做模型预测(只生成 pb 文件、不生成 variables 文件的情况下)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/wshzd/article/details/88840792

Python 脚本(训练模型,并把变量固化为常量后导出为单个 pb 文件)

#!/usr/bin/python
# -*- coding:utf-8 -*-

import tensorflow as tf
from tensorflow import saved_model as sm
import numpy as np

def _tanh_layer(inputs, fan_in, fan_out, name=None):
    """Append a fully connected tanh layer to the default graph.

    Creates a (fan_in, fan_out) weight variable initialised from a standard
    normal and a (1, fan_out) bias variable initialised to zero, then returns
    tanh(inputs @ weights + biases). Ops are created in exactly that order,
    so TensorFlow's auto-generated node names come out the same every time.

    ``name`` names the final tanh op; leave it ``None`` for the default name.
    """
    weights = tf.Variable(tf.random_normal([fan_in, fan_out]))
    biases = tf.Variable(tf.zeros([1, fan_out]))
    return tf.nn.tanh(tf.matmul(inputs, weights) + biases, name=name)


# Toy regression data: 200 evenly spaced points on [-0.5, 0.5] as a column
# vector, with target x^2 plus Gaussian noise (std 0.02).
x_data = np.linspace(-0.5, 0.5, 200)[:, np.newaxis]
noise = np.random.normal(0, 0.02, x_data.shape)
y_data = np.square(x_data) + noise

# "x" is the feed name the Java client uses; the label placeholder is only
# needed for training and keeps TensorFlow's default name.
x = tf.placeholder(tf.float32, [None, 1], name="x")
y = tf.placeholder(tf.float32, [None, 1])

# 1 -> 10 -> 1 tanh network. The output op is named "y" so it can still be
# fetched by name after the graph is frozen.
hidden = _tanh_layer(x, 1, 10)
prediction = _tanh_layer(hidden, 10, 1, name="y")

# Mean squared error, minimised with plain gradient descent (learning rate 0.1).
loss = tf.reduce_mean(tf.square(y - prediction))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

training_feed = {x: x_data, y: y_data}
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(2000):
        sess.run(train_step, feed_dict=training_feed)

    # Predictions on the training inputs (not used further in this script).
    prediction_value = sess.run(prediction, feed_dict={x: x_data})
    # Dump every node name in the graph; handy for choosing feed/fetch names.
    print([node.name for node in sess.graph.as_graph_def().node])

    # Replace the variables with constants holding their trained values so a
    # single .pb file is enough for inference (no separate variables files).
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names=["y"],
    )

    # Write the frozen GraphDef as a binary protobuf.
    with open("model.pb", "wb") as model_file:
        model_file.write(frozen_graph_def.SerializeToString())

Java 代码(加载 pb 文件并做预测)
package XXX;

import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.util.Random;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.FloatBuffer;

public class tflinearmodel {

    public static void main(String[] args) throws IOException {
        String path = "xxx/model.pb";
        try (Graph graph = new Graph()) {
            //导入图
            byte[] graphBytes = IOUtils.toByteArray(new
                    FileInputStream(path));
            graph.importGraphDef(graphBytes);
            System.out.println(graphBytes);

            //根据图建立Session
            try(Session session = new Session(graph)){
                System.out.println("hello");
                float[][] input = new float[1][1];
                input[0] = new float[]{10.0f};
                //相当于TensorFlow Python中的sess.run(z,feed_dict = ({'x': 10.0})
                //float z = session.runner()
                //        .feed("x", Tensor.create(input))
                //        .fetch("y").run().get(0).floatValue();  此脚本增加floatValue()方法后会报错
                //Exception in thread "main" java.lang.IllegalStateException: Tensor is not a scalar                  //脚本修改如下则执行正确
                Tensor z = session.runner()
                           .feed("x", Tensor.create(input))
                           .fetch("y").run().get(0);                   float[][] zz = (float[][]) z.copyTo(new float[1][1]);
                   System.out.println("y = " + zz[0][0]);            }
        }

    }
}

猜你喜欢

转载自blog.csdn.net/wshzd/article/details/88840792
今日推荐