Calling a TensorFlow Model with the Official TensorFlow Java API (with Code)

[Introduction] As TensorFlow grows in popularity, more and more industries want to integrate the large amount of existing TensorFlow code and models on GitHub into their own business systems. How to use TensorFlow from common programming languages (Java, Node.js, etc.) has therefore become a common question. Expert member Hujun gives a detailed introduction to two ways of using TensorFlow in Java, focusing on how to use the official TensorFlow Java API to call an existing TensorFlow model.

1. Two ways for Java to call TensorFlow

There are roughly two ways to call TensorFlow using Java:

  • Call a trained pb model directly through the official TensorFlow Java API: 

    https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary

  • (Recommended) Use KerasServer to host TensorFlow/Keras code and models: 

    https://github.com/CrawlScript/KerasServer


Although a trained pb model can be loaded directly through the official TensorFlow Java API, the cross-language integration still requires a lot of tedious code in practice. Consider a TensorFlow text-classification model that has already been written in Python. The TensorFlow Java API expects text that has already been converted to numbers. So every preprocessing step the Python code already implements, such as word segmentation and the conversion from strings to indices, has to be re-implemented in Java. These steps also depend on data the Python code uses, such as the vocabulary. In addition, Java has no numpy equivalent, so building multi-dimensional input arrays still takes loop-style code, which is very cumbersome.
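The following minimal sketch shows what such a Java re-implementation might look like. It is purely illustrative and not taken from any particular model. The class name `TextToIds`, the toy vocabulary and the index conventions (0 for padding, 1 for unknown words) are assumptions. Whitespace splitting stands in for real word segmentation. Note the nested loops needed to fill a fixed-size `int[][]` batch, work that numpy would otherwise handle.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Illustrative only: a hypothetical Java re-implementation of Python-side
// preprocessing. Whitespace splitting stands in for real word segmentation.
public class TextToIds {

    public static final int PAD_ID = 0; // assumed padding index
    public static final int UNK_ID = 1; // assumed out-of-vocabulary index

    // Map each text to a fixed-length row of word indices (truncate or pad).
    public static int[][] encode(String[] texts, Map<String, Integer> vocab, int maxLen) {
        int[][] ids = new int[texts.length][maxLen];
        for (int i = 0; i < texts.length; i++) {
            Arrays.fill(ids[i], PAD_ID);
            String[] tokens = texts[i].trim().split("\\s+");
            int n = Math.min(tokens.length, maxLen);
            for (int j = 0; j < n; j++) {
                ids[i][j] = vocab.getOrDefault(tokens[j], UNK_ID);
            }
        }
        return ids;
    }

    public static void main(String[] args) {
        // Toy vocabulary; a real one would have to be exported from the Python code.
        Map<String, Integer> vocab = new HashMap<>();
        vocab.put("good", 2);
        vocab.put("movie", 3);
        vocab.put("bad", 4);

        int[][] ids = encode(new String[] {"good movie", "bad plot here"}, vocab, 4);
        for (int[] row : ids) {
            System.out.println(Arrays.toString(row));
        }
        // prints [2, 3, 0, 0] then [4, 1, 1, 0]
    }
}
```

With the TensorFlow 1.x Java API, such an array could then be wrapped with `Tensor.create(ids)` and fed to the model. The vocabulary and segmentation rules would still have to match the Python code exactly, and keeping them in sync is the real maintenance cost.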


KerasServer exposes a RESTful interface, so TensorFlow/Keras can be called from any programming language. The KerasServer server side provides a Python API. Existing TensorFlow/Keras Python code and models can therefore be wrapped directly as KerasServer APIs and called from Java, C, C++, C#, Python, Node.js, browser JavaScript and so on, without re-implementing tedious data preprocessing in those other languages.


For example, Java code can submit the raw text to be classified directly to KerasServer. KerasServer then runs the existing Python code to perform word segmentation and other preprocessing on the strings.
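The following minimal sketch shows the Java side of such a call, using the JDK's built-in `java.net.http` client (Java 11+). The URL `http://localhost:8080/predict`, the JSON field `"text"` and the class name `RestClassifyClient` are hypothetical placeholders. They are not KerasServer's documented API, so take the actual request format from the KerasServer project. The point is that Java sends only the raw string, while segmentation and preprocessing stay in Python on the server.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Sketch only: the endpoint URL and the JSON field name are hypothetical
// placeholders, not KerasServer's documented API.
public class RestClassifyClient {

    // Build {"text": "..."} with minimal JSON string escaping.
    public static String toJson(String text) {
        StringBuilder sb = new StringBuilder("{\"text\": \"");
        for (char c : text.toCharArray()) {
            switch (c) {
                case '"': sb.append("\\\""); break;
                case '\\': sb.append("\\\\"); break;
                case '\n': sb.append("\\n"); break;
                case '\r': sb.append("\\r"); break;
                case '\t': sb.append("\\t"); break;
                default:
                    if (c < 0x20) {
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.append("\"}").toString();
    }

    // POST the raw text as JSON; preprocessing happens on the server.
    public static HttpRequest buildRequest(String url, String text) {
        return HttpRequest.newBuilder(URI.create(url))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(toJson(text)))
                .build();
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        // Hypothetical endpoint: this only succeeds if a server is listening there.
        HttpRequest request = buildRequest("http://localhost:8080/predict", "raw text to classify");
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```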


This tutorial shows how to use the official TensorFlow Java API to call a model trained with TensorFlow in Python. The tutorial code is available in a dedicated GitHub project:

https://github.com/ZhuanZhiCode/TensorFlow-Java-Examples

2. Dependencies

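(1) Python dependency

TensorFlow 1.x. The model-saving code below uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session, tf.log), which is not available at the top level of TensorFlow 2.x. The tf-nightly package now installs a 2.x build, so install a 1.x release instead:

pip install "tensorflow<2"

An older TensorFlow runtime may be unable to load ops that exist only in newer releases. The Java library version below should therefore be at least as new as the Python version used to export the model.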

(2) Java dependencies

This tutorial uses the official TensorFlow Java API, so the following Maven dependency is required:

<dependency>
   <groupId>org.tensorflow</groupId>
   <artifactId>tensorflow</artifactId>
   <version>1.5.0</version>
</dependency>

In addition, the example uses the utility library Apache Commons IO to read the model file into a byte array:

<dependency>
   <groupId>commons-io</groupId>
   <artifactId>commons-io</artifactId>
   <version>2.6</version>
</dependency>
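If you would rather not add Commons IO just to read the model file, the JDK can do the same job. Below is a minimal sketch; the class and method names are illustrative. `Files.readAllBytes` reads the whole file and closes it. The resulting byte array can be passed to `Graph.importGraphDef` in the same way as the `IOUtils` result in section 4.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Reads a serialized GraphDef (e.g. model.pb) into memory using only the JDK,
// as an alternative to Commons IO's IOUtils.toByteArray.
public class ModelBytes {

    public static byte[] readModel(Path path) throws IOException {
        // Opens, fully reads, and closes the file in a single call.
        return Files.readAllBytes(path);
    }

    public static void main(String[] args) throws IOException {
        byte[] graphBytes = readModel(Paths.get("model.pb"));
        System.out.println("Read " + graphBytes.length + " bytes");
    }
}
```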
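
3. Saving the pb model

In the code below, x is the input of the graph and z is its output. At the end of the code, tf.graph_util.convert_variables_to_constants is called to convert the graph. It replaces the graph's variables with constants holding their current values, which is called freezing the graph. The frozen graph is then saved as a model file (pb).
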
#coding=utf-8
import tensorflow as tf


# Define the graph
x = tf.placeholder(tf.float32, name="x")
y = tf.get_variable("y", initializer=10.0)
z = tf.log(x + y, name="z")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Training code would go here (omitted)

    # Print the names of all nodes in the graph
    print([n.name for n in sess.graph.as_graph_def().node])
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names=["z"])

    # Save the frozen graph as a pb file
    with open('model.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
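
4. Calling the TensorFlow graph (pb model) in Java

Running the model works much as it does in Python: import the graph, create a Session, and specify the input (feed) and output (fetch).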

import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

public class DemoImportGraph {

    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            // Import the graph (the frozen pb model), closing the file stream afterwards
            byte[] graphBytes;
            try (InputStream in = new FileInputStream("model.pb")) {
                graphBytes = IOUtils.toByteArray(in);
            }
            graph.importGraphDef(graphBytes);

            // Create a Session for the graph; the input Tensor is closed automatically
            try (Session session = new Session(graph);
                 Tensor<?> x = Tensor.create(10.0f)) {
                // Equivalent to sess.run(z, feed_dict={'x': 10.0}) in TensorFlow Python
                try (Tensor<?> result = session.runner()
                        .feed("x", x)
                        .fetch("z").run().get(0)) {
                    float z = result.floatValue();
                    System.out.println("z = " + z);
                }
            }
        }
    }
}

Output:
z = 2.9957323
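The printed value can be checked by hand. The check assumes the omitted training step leaves the variable y at its initial value of 10.0. The graph computes z = log(x + y) with the natural logarithm, and x is fed as 10.0:

```latex
z = \ln(x + y) = \ln(10.0 + 10.0) = \ln 20 \approx 2.9957323
```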

Full code:

https://github.com/ZhuanZhiCode/TensorFlow-Java-Examples


Source: https://mp.weixin.qq.com/s/hn-LqyREkusxP2TOWfTJ6g



