从0开始使用TensorFlow.js进行深度学习推断

原因

最近项目里面考虑要在微信小程序里面加上识别触发以及检测追踪的功能,所以需要调研一下在Web端进行深度学习前向的可行性。这个可行性主要包含两个方面,一是模型是否支持:能否从现有模型转换为tfjs支持的模型;二是性能:tfjs的前向推断效率是否能满足我们的需求。

名词解释

JavaScript

JavaScript通常简写为JS,是一种高级的解释性脚本语言,遵循ECMAScript标准。JS支持花括号语法、动态类型、基于原型的面向对象等编程语言特性。

JS和HTML以及CSS一起,构成了World Wide Web的技术核心。JS能实现网页的交互,而且是web应用的必须部分。绝大多数的网页都使用了JS,而且大量的浏览器都有专用的JS解释器来执行JS脚本。

作为一个多范式语言,JS支持事件驱动、函数式以及其他必要的语言格式。它提供了文字、数组、日期、正则表达式以及树形文件结构的API,但是JS本身并不包括任何I/O,比如网络、存储、或者图形功能。这些功能都必须依赖宿主平台提供。

WebGL

WebGL是一组JavaScript的API,可用来在兼容的浏览器上渲染可交互的2d或3d图形,且不需要使用其他插件。WebGL和其他浏览器标准高度融合,允许使用GPU来对图形渲染或图像处理进行加速。

TensorFlow.js(tfjs)

tfjs是TensorFlow的JavaScript语言机器学习库。支持完整的模型读取、训练以及前向推断等功能,并使用WebGL对整个过程进行加速。可以参考官方例子

node.js

Node.js是一个开源的、跨平台的JavaScript运行时环境,可以在浏览器之外运行JavaScript脚本,通常用来作为服务器端开发语言。

方案

1.使用node.js启动一个本地的https服务;
2.实现本地页面的html以及对应的JS代码index.js;
3.在初始化时读取tfjs版的深度学习模型;
4.以全0Tensor作为输入,点击对应按钮进行若干次推断;
5.统计平均单次推断时间,并显示在页面上.

实现

安装node.js

参考此文

启动https服务

参考此文
实现一个server.js,其中包括https服务,然后node server.js启动服务。

实现浏览器页面html

参考此文

实现页面的JS代码

var text = document.getElementById("text_area");
var MOBILENET_MODEL_PATH = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);

var timeBegin = Date.now();
mobilenet.predict(tf.zeros([1, HAND_IMAGE_SIZE, HAND_IMAGE_SIZE, 3])).dispose();
var timeEnd = Date.now();

text.value = (timeEnd - timeBegin)/10.0;

使用tf.loadLayersModel加载模型,使用模型的predict方法进行推断。

在浏览器中打开index.html

难点

在项目中使用tfjs

因为我们是在前端页面上使用,所以最简单的方法,只需要加载这个库就可以,在html中添加:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
就可以直接使用tf.xxx。需要注意的是,script标签是顺序加载的,所以依赖库最好放到index.js上面。

模型转换

参考此文
在这里插入图片描述

加载模型

如果参考了官方例子以及官方API,会发现一共有三个load接口:

tf.loadGraphModel (modelUrl, options?)
tf.loadLayersModel (pathOrIOHandler, options?)
tf.loadFrozenModel (modelUrl, weightsManifestUrl, requestOption?)

其中第三个是0.13版本的API,前两个才是2.0版本的API,所以果断放弃第三个。再来对比前两个,官方是这么说的:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OeivLhoR-1574145990615)(evernotecid://1CA468D3-8108-4F95-9FF0-B3384CD16BE9/appyinxiangcom/22266324/ENResource/p627)]
简单点说就是如果模型是使用tf.LayersModel.save()方法保存的,就使用tf.loadLayersModel加载,否则一律使用tf.loadGraphModel

浏览器版本兼容性

如果一切顺利,应该是可以正常打开网页并加载运行模型。但是还有一个坑,如果浏览器版本过低,有可能会不支持某些特性,比如async function等:
在这里插入图片描述

调试

如果不幸遇到问题,可能无法运行成功,这个时候非常想能够有一个调试工具,可以在手机上直接调试,这里推荐vConsole,简单易用,参考此文

结果

一些可能的参考结果:
mobilenet_v2_1.0_224
在1+6手机(高通骁龙845)上,UC浏览器,平均每次predict需要36ms;
在iPhone6s plus(苹果 A9+M9),safari浏览器,平均60ms。

参考

https://www.runoob.com/w3cnote/javascript-settimeout-usage.html
https://www.cnblogs.com/lvdabao/p/es6-promise-1.html
https://www.tensorflow.org/js/tutorials/setup
https://developer.mozilla.org/zh-CN/docs/Web/JavaScript/Reference/Statements/async_function
https://medium.com/cubo-ai/tensorflow-js-初探-yolo-3be7e63f7c96

发布了42 篇原创文章 · 获赞 33 · 访问量 7万+

猜你喜欢

转载自blog.csdn.net/gaussrieman123/article/details/103142160