Tensorflow.js运行Python下训练的模型

一、引言

这两天的项目需要用到Tensorflow.js来实现一个AI,尽管说Tensorflow.js本身是有训练模型的功能的,不过考虑到javascript这个东西加载资源要考虑跨域问题等种种因素。。最终还是决定使用python的tensorflow来训练模型,然后利用js端来使用模型进行运算,那么关键问题就是:js如何加载python下训练的模型

【webAI】Tensorflow.js加载预训练的model

这位博主的博客给了我很大的帮助,前两步按照他的教程来做都是没有什么问题的,不过其实还是有一些潜在的坑或者对于我这种前端小白不太友好的地方,这里我把我的整个过程都来叙述一遍吧。

注:首先在命令行中执行

pip install tensorflowjs

安装模型转换的部分,否则转换可能会报错

二、python部分

这里我用了一个更加简单的例子,and.py,让神经网络来学习异或运算(忽略这个"and"emmm),这里我直接把python代码贴出来:

#coding=utf-8#
import tensorflow as tf
import numpy as np
x_data=[[0.0,0.0],[0.0,1.0],[1.0,0.0],[1.0,1.0]]	#训练数据
y_data=[[0.0],[1.0],[1.0],[0.0]]	#标签
x_test=[[0.0,1.0],[1.0,1.0]]	#测试数据
xs=tf.placeholder(tf.float32,[None,2])
ys=tf.placeholder(tf.float32,[None,1])	#定义x和y的占位符作为将要输入神经网络的变量

#构建隐藏层,假设隐藏层有20个神经元
W1=tf.Variable(tf.random_normal([2,10]))
B1=tf.Variable(tf.zeros([1,10])+0.1)
out1=tf.nn.relu(tf.matmul(xs,W1)+B1)
#构建输出层,假设输出层有一个神经元
W2=tf.Variable(tf.random_normal([10,1]))
B2=tf.Variable(tf.zeros([1,1])+0.1)
prediction=tf.add(tf.matmul(out1,W2),B2,name="model")
#计算预测值和真实值之间的误差
loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),reduction_indices=[1]))
train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss)

init=tf.global_variables_initializer()	#初始化所有变量
sess=tf.Session()
sess.run(init)

for i in range(40):	#训练10次
	sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
	print(sess.run(loss,feed_dict={xs:x_data,ys:y_data}))	#打印损失值
re=sess.run(prediction,feed_dict={xs:x_test})
print(re)
for x in re:   
	if x[0]>0.5:
		print(1)
	else:
		print(0)
# 保存模型为saved_model
tf.saved_model.simple_save(sess, "./saved_model",inputs={"x": xs, }, outputs={"model": prediction, })

这个代码非常简单,由于异或运算就四种可能性,所以数据也很小,这里应该也很好理解,保存的部分也是照着那个博主的部分来写的。

三、模型转换

目录下的文件
运行and.py得到saved_model

首先运行这个python程序,会得到如图所示的文件夹:

然后,在控制台中运行如下的命令:

tensorflowjs_converter --input_format=tf_saved_model --output_node_names="model" --saved_model_tags=serve ./saved_model ./web_model
运行成功的样子
出现model.pb,说明成功了

就目前看来,这里的 --output_node_names应该就和python文件中的outputs中的字典键名一致。

扫描二维码关注公众号,回复: 3956718 查看本文章
web_model
如图所示,生成了web_model

总之,这一步应该也没有什么问题,运行成功后,会生成如图所示的web_model文件夹。

文件夹下的文件
这三个文件一会都需要用

好了,接下来是我踩的第一个坑,在web_model下有三个文件,那位博主只说了其中两个的作用,于是我傻乎乎的以为就需要这两个,然后在最后一步浏览器运行的时候,一直输出无穷大。。总之,只要记住,这个文件夹下的东西待会都是要用的就好了。

四、在web中运行

好了,这一步对我来说就是一个巨坑了,作为一个前端小白,我分别见识到了"浏览器跨域问题"和"ES6的import语句需要编译才能被浏览器识别"两大问题,第二个问题昨天花了两个多小时学会了简单的ES6标准编译结果今天查资料发现不用import语句这程序也能在浏览器中跑emmm,这样,我分别来讲述一下这两个问题吧:

1.避免使用Import语句

这个很好做到,只需要把它变成script的引用就行,这里我把我测试的文件的源代码贴出来:

<!doctype html>
<html lang="en">
<head>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"> </script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-converter"></script>
</head>
<body>
  <img style="display: none" id="cat" src="high-detail.jpg" width="224" height="224">

  <script>
const MODEL_URL = './tensorflowjs_model.pb'
const WEIGHTS_URL = './weights_manifest.json'
async function fun(){
    const model = await tf.loadFrozenModel(MODEL_URL, WEIGHTS_URL)
      const cs = tf.tensor([[1.0,1.0],[0.0,0.0]])
	cs.print()
	model.predict(cs).print()
}
fun()
  </script>
</body>
</html>

基本上也是按那个博主的博客改的吧,有两个地方需要注意

第一个是loadFrozenModel这个函数是被tf这个对象调用的,这一点也是我今天查资料时发现的,原文是这样的:

The difference here is there is no more "tf_converter", only "tf". You call "tf.loadFrozenModel" like you would any other tf op.

大概意思就是说以后都会变成tf.loadxxx这样子的了

第二个是,这个函数写完以后,记得最后要调用,否则打开网页什么都看不到,毕竟我也是花了五分钟看着空白的控制台懵逼的人。

2.浏览器跨域问题

这个世界就是这么神奇,我写博客这会居然可以直接访问同目录下的pb文件和json文件了,不过之前一直都会报CORS问题,之前我是通过安装配置一个web服务器来访问页面的,我用了tomcat,将HTML文件和pb、json等三个文件置于其webapps的目录下然后启动服务器就能正常访问了:

目录情况
webapps下的情况
访问效果
访问指定的URL,控制台中打印出了结果

 

当然,如果你根本就没遇到跨域问题,只要忽略这一步即可。

猜你喜欢

转载自blog.csdn.net/zekdot/article/details/82913636
今日推荐