简单的demo带你将TensorFlow的pb格式的模型移植到Android平台

最近想把在tensorflow上训练的模型移植到安卓上运行,看了一下网上的例子,感觉都很复杂,对于我这种不太会安卓代码的人很不友好,所以决定自己摸索,再看了tensorflow官方的demo后,决定写出了下面这个简易版demo,带你快速了解如何将pb模型移植到安卓上运行。

我的环境:

  • windows10
  • python3.7
  • tensorflow-gpu1.14
  • pycharm
  • android studio

这个demo并不是针对移植神经网络模型,而是针对pb文件的调用,所以我只写了一段简单的代码来生成pb文件。下面先贴出我的pb文件生成代码。

# -*- coding:utf-8 -*-
# 这是一个基于tensorflow的简单计算,并且保存模型为pd文件
import tensorflow as tf

sess = tf.Session()

matrix1 = tf.placeholder(tf.float32, [2, ], name='input1')
matrix2 = tf.placeholder(tf.float32, [2, ], name='input2')
mat_add = tf.add(matrix1, matrix2, name='output1')
mat_sub = tf.subtract(matrix1, matrix2, name='output2')

res1 = sess.run(mat_add, feed_dict={matrix1: [4, 6], matrix2: [3, 1]})
res2 = sess.run(mat_sub, feed_dict={matrix1: [4, 6], matrix2: [3, 1]})
print("res1=", res1)
print("res2=", res2)

# 保存二进制模型
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,
                                                                output_node_names=['output1', 'output2'])  # output_node_names指定要保存哪些输出tensor
with tf.gfile.FastGFile('test.pb', mode='wb') as f:
    f.write(output_graph_def.SerializeToString())

sess.close()

这个模型非常简单,两个输入tensor两个输出tensor,一个加法运算OP,一个减法运算OP。这么简单的模型,应该算是非常亲民了吧。

然后我们使用android studio的模板生成一个新的工程,就是那个能够直接打印出hollow world的那个模板。

首先添加一个名为assets资源文件夹到app/src/main/里面,然后把test.pb文件放入assets文件夹中,如下图所示:

之后修改在app/src文件夹下的build.gradle,在“dependencies”里添加(印象中需要翻墙,想不翻墙的话就自己编译库吧):

implementation 'org.tensorflow:tensorflow-android:1.13.1'

用来加入tensorflow的aar库。如下图所示:(可以版本号改为+号,来达到不指定版本的目的)

最后在MainActivity.java中进行修改,调用pb模型并运行和打印结果。

import androidx.appcompat.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
//加入tensorflow支持
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class MainActivity extends AppCompatActivity {
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        String MODEL_FILE = "file:///android_asset/test.pb";  //pb文件的位置
        TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(getAssets(),MODEL_FILE);  //从载入模型

        float[] input1 = new float[3];  //一般采用float数组作为输入、输出
        float[] input2 = new float[3];  //具体使用什么类型依网络中实际情况而定
        input1[0] = (float) 5.0; input1[1] = (float) 6.0; input1[2] = (float) 1.0;
        input2[0] = (float) 2.0; input2[1] = (float) 3.0; input2[2] = (float) 2.0;
        float[] output1 = new float[3];
        float[] output2 = new float[3];

        //喂入输入数据,格式为“输入tensor名称(与模型中设定name的一致)”,“输入数据”,“数据的shape”
        inferenceInterface.feed("input1", input1, new long[]{1,3}); 
        inferenceInterface.feed("input2", input2, new long[]{1,3});
        
        //执行run,格式为“输出tensor名称(与模型中设定name的一致)”
        inferenceInterface.run(new String[]{"output1","output2"});

        //获取输出结果,格式为“输出tensor名称(与模型中设定name的一致)”,“输出数据”
        inferenceInterface.fetch("output1", output1);
        inferenceInterface.fetch("output2", output2);
        
        //打印结果
        for(float f : output1)
            Log.e("111111", "output1: " + f);
        for(float f : output2)
            Log.e("111111", "output2: " + f);
    }
}

然后直接安装运行就可以看到如下打印结果:

结果正确。

细心地小伙伴应该发现了一个问题,就是在安卓端每个input都输入了3个数,而模型定义的placeholder的输入shape为2个数。似乎tensorflow并没有对这个输入的大小进行判断,具体原因我也不清楚。不过我这个模型过于简单,计算内容的确不会严重依赖于shape,所以还是劝大家严格按照模型约定的大小进行输入。

另外,我还尝试过将python中的placeholder改成变量,name保持不变。它生成的模型依然可以运行出正确的结果,我估计tensorflow只是强制把输入按name进行对应,也不管它的shape或者类型(placeholder,变量等类型)。

最终总结:

1.在安卓端调用pb模型主要依靠TensorFlowInferenceInterface。

2.模型运行与python上类似,使用run指定输出tensor,就可以运行相应的节点得到结果

3.输入输出tensor的name必须严格对应。

4.python上的 sess.run中的feed_dict被作为一个单独的API,在TensorFlowInferenceInterface里的feed。

5.使用TensorFlowInferenceInterface里的fetch得到推理结果。

其实我也是个tensorflow和android的初学者,如果有什么错误,希望大家能够帮我指出,谢谢大家。

猜你喜欢

转载自blog.csdn.net/qq_19313495/article/details/97375711