TensorFlow Lite简单分析

转载自:https://www.jianshu.com/p/f07c39f7b22b

一、移动端深度学习SDK比较

11月14日,谷歌终于开源了业界期待已久的移动版TensorFlow — TensorFlow Lite(Github传送门)。笔者认为,谷歌很可能凭借这款利器赢得移动端AI的生态之战。

理由有三:

1. 无缝支持通过TensorFlow训练好的神经网络模型。只需要几个简单的步骤就可以完成桌面模型到移动端模型的转换。

2. TFLite可以与Android 8.1中发布的神经网络API完美配合。而Android端版本演进的控制权是掌握在谷歌手中的,从长期看,TFLite会得到Android系统层面上的支持。

3. 质量有保证。根据以往谷歌对开源项目支持力度看,TFLite的功能迭代演进会很快,大量的bug会在第一时间修复。

此前,国内的巨头百度已经发布了MDL(传送门 )框架、腾讯发布了NCNN(传送门 )框架。下面笔者将比较下这三个移动端框架的异同之处。

相同点:

1. 只含推理(inference)功能,使用的模型文件需要通过离线的方式训练得到。

2. 最终生成的库尺寸较小,均小于500kB。

3. 为了提升执行速度,都使用了ARM NEON指令进行加速。

4. 跨平台,iOS和Android系统都支持。

不同点:

1. MDL和NCNN均是只支持Caffe框架生成的模型文件,而TfLite则毫无意外的只支持自家大哥TensorFlow框架生成的模型文件。

2. MDL支持利用iOS系统的Matal框架进行GPU加速,能够显著提升在iPhone上的运行速度,达到准实时的效果。而NCNN和TFLite还没有这个功能。


二、生成TFLite模型文件

简介

在桌面PC或是服务器上使用TensorFlow训练出来的模型文件,不能直接用在TFLite上运行,需要使用离线工具先转成.tflite文件。笔者发现官方文档中很多细节介绍的都不太明确,在使用过程中需要不断尝试。我把自己的尝试过的步骤分享出来,希望能帮助大家节省时间。

具体说来,tflite文件的生成大致分为3步:

1. 在算法训练的脚本中保存图模型文件(GraphDef)和变量文件(CheckPoint)。

2. 利用freeze_graph工具生成frozen的graphdef文件。

3. 利用toco工具,生成最终的tflite文件。

图1. 生成tflite文件的整个流程示意图

第1步:导出图模型文件和变量文件

在你的算法的训练或推理任务的脚本中,利用tensorflow.train中的write_graph和saver API来导出GraphDef及Checkpoint文件。


图2. TensorFlow中导出GraphDef文件和Checkpoint文件

其中,tf.train.write_graph一行将导出模型的GraphDef文件,实际上保存了训练的神经网络的结构图信息。存储格式为protobuffer,所以文件名后缀为pb。


图3. 导出的GraphDef文件

tf.train.saver.save一行导出的是模型的变量文件,实际上保存了整个图中所有变量目前的取值。


图4. 导出的checkpoint文件

如图4所示,实际上产生了4个文件。在后续步骤中需要用到的是nsfw_model.ckpt.data-00000-of-00001这个文件,保存了当前神经网络各参数的取值。

第2步:生成frozen的graphdef文件

在此步骤中,使用Tensorflow源代码中自带的freeze_graph工具,生成一个frozen的GraphDef文件。

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/data/deep_learning/nsfw/model/nsfw-graph.pb --input_checkpoint=/data/deep_learning/nsfw/model/nsfw_model.ckpt --input_binary=true --output_graph=/data/deep_learning/nsfw/model/frozen_nsfw.pb --output_node_names=predictions

这里有两个地方容易搞错。第一个地方,input_checkpoint参数实际上用到的文件应该是nsfw_model.ckpt.data-00000-of-00001,但是在指定文件名的时候只需要指定nsfw_model.ckpt即可。第二个地方,是output_node_names参数,此处指定的是神经网络图中的输出节点的名字,是在训练阶段的Python脚本中定义的。如下图所示,在定义网络结构时,输出节点的名称为"predictions"。则最终output_node_names需要指定为“predictions”。


图5. output_node_names参数取值与网络模型定义时的名字要对应

当然,也可以利用summarize_graph打印出模型的输入和输出节点,如:

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/data/deep_learning/nsfw/model/frozen_nsfw.pb

图6. 输入节点名称为input

图7. 输出节点名称为predictions

第3步:生成最终的tflite文件

在此步骤中,使用Tensorflow源代码中自带的toco工具,生成一个可供TensorFlow Lite框架使用tflite文件。其中input_arrays和output_arrays的名称需要与定义网络类型时取的名称保持一致。

bazel run --config=opt tensorflow/contrib/lite/toco:toco --input_file=/data/deep_learning/nsfw/model/frozen_nsfw.pb --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --output_file=/data/deep_learning/nsfw/model/nsfw.lite --inference_type=FLOAT --input_type=FLOAT --input_arrays=input --output_arrays=predictions --input_shapes=1,224,224,3

生成的nsfw.lite文件即可用于TensorFlow Lite应用。



猜你喜欢

转载自blog.csdn.net/The_star_is_at/article/details/80608435