【webAI】Tensorflow.js加载预训练的model

环境准备

  • win10
  • python3.6
  • pip install tensorflow
  • pip install tensorflowjs

训练并保存tensorflow模型为saved_model

# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

# Download the MNIST dataset; labels are loaded one-hot encoded
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Initialise the session (the later .run()/.eval() calls rely on it being
# the default session)
sess = tf.InteractiveSession()

def weight_variable(shape):
    """Return a tf.Variable of the given shape with random initial values.

    The initial values are drawn from a truncated normal distribution with
    a standard deviation of 0.1.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    """Return a tf.Variable of the given shape with every element set to 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))

# Network dimensions
n_input = 784   # size of one input vector fed to placeholder "x"
n_node = 256    # number of units in the hidden layer
n_out = 10      # number of output classes

# Graph inputs: the features are named "x" so the exported model (and the
# TensorFlow.js code that feeds it) can address the input by name.
x = tf.placeholder(tf.float32, [None, n_input], name="x")
y_ = tf.placeholder(tf.float32, [None, n_out])

# Hidden layer: affine transform followed by ReLU
W = weight_variable([n_input, n_node])
b = bias_variable([n_node])
layer_h = tf.nn.relu(tf.matmul(x, W) + b)

# Output layer
# The weight matrix must come from weight_variable (random initial values);
# bias_variable would start every weight at the same constant 0.1.
W_out = weight_variable([n_node, n_out])
b_out = bias_variable([n_out])
# y holds the raw logits. No activation is applied here: both tf.nn.softmax
# and softmax_cross_entropy_with_logits_v2 expect unscaled logits, and a ReLU
# would clamp negative logits to 0 and block their gradient.
y = tf.matmul(layer_h, W_out) + b_out

# Class probabilities; the op name "softmax" is the output node name used
# when the model is exported and converted.
softmax = tf.nn.softmax(y, name="softmax")

# Loss: mean softmax cross-entropy between the labels y_ and the logits y
per_example_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=y_, logits=y)
cross_entropy = tf.reduce_mean(per_example_loss)

# Accuracy: fraction of examples whose arg-max prediction matches the label
predicted_class = tf.argmax(softmax, 1)
expected_class = tf.argmax(y_, 1)
correct_prediction = tf.equal(predicted_class, expected_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Train the model with Adam (default hyper-parameters)
train_step = tf.train.AdamOptimizer().minimize(cross_entropy)

tf.global_variables_initializer().run()
for step in range(2000):
    images, labels = mnist.train.next_batch(50)
    feed = {x: images, y_: labels}
    # Report accuracy on the current mini-batch every 200 steps,
    # before the update for this step is applied.
    if step % 200 == 0:
        train_accuracy = accuracy.eval(feed_dict=feed)
        print('step %d, training accuracy %g' % (step, train_accuracy))
    train_step.run(feed_dict=feed)

# Final evaluation on the full test split
test_feed = {x: mnist.test.images, y_: mnist.test.labels}
print('test accuracy %g' % accuracy.eval(feed_dict=test_feed))

# Export the trained graph as a SavedModel in ./saved_model, exposing the
# "x" placeholder as the input and the "softmax" op as the output
export_inputs = {"x": x}
export_outputs = {"softmax": softmax}
tf.saved_model.simple_save(
    sess, "./saved_model", inputs=export_inputs, outputs=export_outputs)

转换tensorflow的模型

# Convert the SavedModel in ./saved_model (tag "serve") into a TensorFlow.js
# web model written to ./web_model, keeping the "softmax" node as the output.
tensorflowjs_converter --input_format=tf_saved_model \
  --output_node_names="softmax" \
  --saved_model_tags=serve ./saved_model \
  ./web_model
  • 转换后的模型文件（输出在 ./web_model 目录下）：
  • tensorflowjs_model.pb：tensorflow.js 能识别的模型文件
  • weights_manifest.json：tensorflow.js 能识别的模型参数文件


（图：转换后生成的模型文件）


Tensorflow.js加载转换后的模型

import * as tf from '@tensorflow/tfjs'
import {loadFrozenModel} from '@tensorflow/tfjs-converter'

// Converted model file produced by tensorflowjs_converter
const MODEL_URL = 'tensorflowjs_model.pb'
// Weights manifest produced alongside the converted model
const WEIGHTS_URL = 'weights_manifest.json'

/**
 * Load the converted frozen model and run a single prediction.
 *
 * Reads `pixels` from the enclosing scope (it is not defined in this
 * snippet); it is wrapped in an array to form a one-row 2-D tensor, so it is
 * presumably a flat array of pixel values for one image. TODO confirm with
 * the caller.
 *
 * @returns {Promise<tf.Tensor|undefined>} the tensor returned by
 *   model.execute, or undefined if loading or inference threw (the error is
 *   logged, not re-thrown). The caller owns the returned tensor.
 */
async function predict() {
    let xs
    try {
      const model = await loadFrozenModel(MODEL_URL, WEIGHTS_URL)
      xs = tf.tensor2d([pixels])
      // Feed the input by the placeholder name used in the model: "x"
      const output = model.execute({x: xs})
      console.log(output.dataSync())
      return output
    } catch (e) {
      console.log(e)
    } finally {
      // TF.js tensors are not garbage-collected; release the input tensor
      // explicitly so repeated calls do not leak memory.
      if (xs) {
        xs.dispose()
      }
    }
}

猜你喜欢

转载自blog.csdn.net/ns2250225/article/details/80030371
今日推荐