[Introduction] As TensorFlow has grown in popularity, more and more industries want to integrate the large body of existing TensorFlow code and models on GitHub into their own business systems. How to use TensorFlow from common programming languages (Java, NodeJS, etc.) has therefore become a frequent question. Expert member Hujun presents two ways of using TensorFlow from Java, focusing on how to call an existing TensorFlow model through the official TensorFlow Java API.
1. Two ways for Java to call TensorFlow
There are roughly two ways to call TensorFlow using Java:
Call a trained pb model directly through the official TensorFlow Java API:
https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary
(Recommended) Use KerasServer to host TensorFlow/Keras code and models:
https://github.com/CrawlScript/KerasServer
The official TensorFlow Java API can load a trained pb model directly. In practice, however, this still requires a lot of tedious cross-language glue code. For example, suppose you already have a TensorFlow text classifier written in Python. The TensorFlow Java API expects vectorized (numeric) text as input. You would therefore have to re-implement in Java the preprocessing that the Python code already performs, such as word segmentation and string-to-index conversion. These steps also depend on data used by the Python code, such as the vocabulary. In addition, Java has no numpy equivalent, so building multi-dimensional input arrays requires hand-written loops, which is very cumbersome.
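The following is a minimal sketch of the glue code the Java side ends up owning. Everything in it is an assumption for illustration:

- The vocabulary map stands in for the word list exported from the Python side.
- Whitespace splitting stands in for real word segmentation; Chinese text would need a proper segmenter.
- Index 0 is assumed to mean padding and index 1 an unknown word.
- Every sentence is padded or truncated to a fixed length so the batch forms a rectangular int[][].

A real port would have to mirror whatever conventions the Python preprocessing actually uses.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/** Hypothetical Java port of Python-side text preprocessing (illustration only). */
final class TextPreprocessor {
    static final int PAD_ID = 0; // assumed padding index
    static final int UNK_ID = 1; // assumed index for out-of-vocabulary words

    private final Map<String, Integer> vocab; // assumed word-to-index table from the Python side
    private final int maxLen;                 // assumed fixed sequence length expected by the model

    TextPreprocessor(Map<String, Integer> vocab, int maxLen) {
        this.vocab = vocab;
        this.maxLen = maxLen;
    }

    /** Converts one sentence to a fixed-length index array (whitespace split stands in for segmentation). */
    int[] toIndices(String text) {
        int[] ids = new int[maxLen];
        Arrays.fill(ids, PAD_ID);
        String trimmed = text.trim();
        if (trimmed.isEmpty()) {
            return ids;
        }
        String[] words = trimmed.split("\\s+");
        for (int i = 0; i < Math.min(words.length, maxLen); i++) {
            ids[i] = vocab.getOrDefault(words[i], UNK_ID);
        }
        return ids;
    }

    /** Without numpy, a [batch, maxLen] input has to be assembled with explicit loops. */
    int[][] toBatch(List<String> texts) {
        int[][] batch = new int[texts.size()][];
        for (int i = 0; i < texts.size(); i++) {
            batch[i] = toIndices(texts.get(i));
        }
        return batch;
    }

    public static void main(String[] args) {
        Map<String, Integer> vocab = Map.of("tensorflow", 2, "java", 3, "api", 4);
        TextPreprocessor preprocessor = new TextPreprocessor(vocab, 5);
        int[][] batch = preprocessor.toBatch(List.of("tensorflow java api", "java rocks"));
        System.out.println(Arrays.deepToString(batch)); // prints [[2, 3, 4, 0, 0], [3, 1, 0, 0, 0]]
    }
}
```

This illustrates the maintenance burden described above: two implementations of the same preprocessing, both tied to the same vocabulary file, must be kept in sync. The example uses Java 9+ collection factories (Map.of, List.of).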
KerasServer supports RESTful interaction, so TensorFlow/Keras can be called from any programming language. The KerasServer server side exposes a Python API. Existing TensorFlow/Keras Python code and models can therefore be turned directly into a KerasServer API. Java, C/C++, C#, Python, NodeJS, browser JavaScript and other clients can call that API without re-implementing tedious data preprocessing in their own language.
For example, Java can submit the raw text to be classified directly to KerasServer. KerasServer then runs the existing Python code to perform word segmentation and other preprocessing on the strings.
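On the Java side, such a REST call needs nothing beyond the JDK's HTTP client (Java 11+). The sketch below only illustrates the shape of the interaction. The endpoint URL and the JSON field name are hypothetical placeholders, not KerasServer's documented API. Check the KerasServer README for the real request format before adapting it.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

/**
 * Sketch of sending raw text to a REST model server from Java 11+.
 * The endpoint and JSON body shape are HYPOTHETICAL, not KerasServer's documented API.
 */
final class RestClassifyClient {
    // Hypothetical endpoint; replace with the server's real address and path.
    static final String ENDPOINT = "http://localhost:8000/predict";

    /** Builds a POST request carrying the raw, unpreprocessed text as JSON. */
    static HttpRequest buildRequest(String endpoint, String text) {
        // Minimal escaping of backslashes and quotes only; use a JSON library in real code.
        String escaped = text.replace("\\", "\\\\").replace("\"", "\\\"");
        String json = "{\"text\": \"" + escaped + "\"}"; // hypothetical field name
        return HttpRequest.newBuilder(URI.create(endpoint))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(json))
                .build();
    }

    public static void main(String[] args) throws Exception {
        HttpRequest request = buildRequest(ENDPOINT, "TensorFlow makes Java integration easier");
        HttpClient client = HttpClient.newHttpClient();
        // Segmentation, index lookup, etc. would run server-side in the existing Python code.
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```

With this design the Java client sends only the raw string. All model-specific preprocessing stays in one place, on the Python side.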
This tutorial shows how to use the official TensorFlow Java API to call a model trained with TensorFlow (Python). The tutorial's code is available in a dedicated GitHub project:
https://github.com/ZhuanZhiCode/TensorFlow-Java-Examples
2. Dependencies
(1) Python dependencies
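TensorFlow (the model-saving script below uses the TensorFlow 1.x graph API: tf.placeholder, tf.Session, tf.log and so on):
pip install tf-nightly
Current tf-nightly builds are TensorFlow 2.x. In 2.x these calls are only available under tf.compat.v1 with eager execution disabled, so install a 1.x release if you want to run the script unchanged.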
(2) Java dependencies
This tutorial uses the official Java API provided by TensorFlow, so the following Maven dependency is required:
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.5.0</version>
</dependency>
In addition, a utility library is needed (commons-io, whose IOUtils is used to read the model file):
<dependency>
    <groupId>commons-io</groupId>
    <artifactId>commons-io</artifactId>
    <version>2.6</version>
</dependency>
3. Save the pb model
# coding=utf-8
import tensorflow as tf

# Define the graph: z = log(x + y)
x = tf.placeholder(tf.float32, name="x")
y = tf.get_variable("y", initializer=10.0)
z = tf.log(x + y, name="z")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Training code would go here (omitted)
    # Print the names of the nodes in the graph
    print([n.name for n in sess.graph.as_graph_def().node])
    # Freeze the graph: replace variables with constants holding their current values
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=["z"])

# Save the frozen graph as a pb file
with open('model.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())
4. Calling the TensorFlow graph (pb model) from Java
Running the model works much as it does in Python. You import the graph, create a Session, and then specify the inputs (feed) and outputs (fetch).
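Importing the graph only needs the raw bytes of model.pb. The demo below reads them with commons-io's IOUtils.toByteArray. If you would rather not add commons-io just for that, the JDK's java.nio.file.Files.readAllBytes (Java 7+) does the same job. This small helper sketches that swap, and the resulting byte[] could be passed to graph.importGraphDef unchanged. The class name GraphBytes is hypothetical.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

/** JDK-only replacement for IOUtils.toByteArray(new FileInputStream(...)). */
final class GraphBytes {
    /** Reads the whole serialized GraphDef; the file is opened and closed internally. */
    static byte[] load(String path) throws IOException {
        return Files.readAllBytes(Paths.get(path));
    }

    public static void main(String[] args) throws IOException {
        byte[] graphBytes = load(args.length > 0 ? args[0] : "model.pb");
        System.out.println("Read " + graphBytes.length + " bytes");
        // graph.importGraphDef(graphBytes) would follow, as in the demo.
    }
}
```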
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

public class DemoImportGraph {
    public static void main(String[] args) throws IOException {
        // Read the serialized graph; try-with-resources closes the file stream
        byte[] graphBytes;
        try (InputStream in = new FileInputStream("model.pb")) {
            graphBytes = IOUtils.toByteArray(in);
        }
        try (Graph graph = new Graph()) {
            // Import the graph
            graph.importGraphDef(graphBytes);
            // Create a Session for the graph; input and output Tensors are closed too
            try (Session session = new Session(graph);
                 Tensor<?> x = Tensor.create(10.0f);
                 // Equivalent to sess.run(z, feed_dict={'x': 10.0}) in TensorFlow Python
                 Tensor<?> result = session.runner().feed("x", x).fetch("z").run().get(0)) {
                float z = result.floatValue();
                System.out.println("z = " + z);
            }
        }
    }
}
Output:
z = 2.9957323
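The printed value can be checked by hand. x is fed as 10.0, and the variable y was initialized to 10.0 (assuming the omitted training step left it unchanged). So z = log(x + y) = ln 20 ≈ 2.9957323. Here is the same arithmetic in plain Java, computed in float to mirror the tf.float32 graph. The class name ExpectedZ is only illustrative.

```java
/** Recomputes the demo graph z = log(x + y) without TensorFlow, in float precision. */
final class ExpectedZ {
    static float compute(float x, float y) {
        // tf.log is the natural logarithm; Math.log works in double, so cast back to float.
        return (float) Math.log(x + y);
    }

    public static void main(String[] args) {
        float z = compute(10.0f, 10.0f); // x fed from Java, y = initial value of the variable
        System.out.println("z = " + z); // prints z = 2.9957323
    }
}
```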
Full code:
https://github.com/ZhuanZhiCode/TensorFlow-Java-Examples
Source: https://mp.weixin.qq.com/s/hn-LqyREkusxP2TOWfTJ6g