网络模型保存为 pb 格式（二）

这是网上看到的一个程序，运行了一下，有些地方还有错误，先存下来，其中有些写法可以借鉴。
import numpy
import tensorflow as tf
from tensorflow import graph_util as tf_graph_util
from tensorflow.contrib import rnn as tfc_rnn


def v1(data):
    """Run a freshly initialised 7-unit BasicRNNCell over *data* and print its final state.

    The input placeholder leaves the first two dimensions unknown and fixes the
    last one at 5. With ``dynamic_rnn``'s default batch-major layout that means
    *data* is expected as ``(batch, time, 5)``. No graph freezing happens here:
    this is the in-memory baseline the frozen-graph variants are compared with.

    Args:
        data: Array-like input fed to the ``float32`` placeholder.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Graph-level seed so the random weight initialisation is repeatable.
        tf.set_random_seed(1)
        inputs = tf.placeholder(tf.float32, shape=(None, None, 5))
        rnn_cell = tfc_rnn.BasicRNNCell(7)
        # dynamic_rnn returns (outputs, final_state); only the state is used.
        final_state = tf.nn.dynamic_rnn(rnn_cell, inputs, dtype=tf.float32)[1]

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            state_value = sess.run(final_state, feed_dict={inputs: data})
            print(state_value)


def v2a(shape=(2, 3, 5)):
    """Build, initialise and freeze a 7-unit BasicRNNCell graph.

    The input placeholder is named ``"x"``. The frozen graph can therefore be
    fed directly through ``"x:0"`` after import, or have that input replaced via
    ``input_map``.

    Args:
        shape: Static shape of the ``x`` placeholder. Defaults to ``(2, 3, 5)``,
            the shape this function always used before. NOTE(review): a fully
            static shape may get baked into the frozen graph (e.g. as the batch
            size of the RNN's zero state). In that case, pass something like
            ``(None, None, 5)`` if the frozen graph must accept other batch
            sizes or sequence lengths. Confirm against the TF version in use.

    Returns:
        A ``(graph_def, state_name)`` tuple:
        - the GraphDef with the variables converted to constants;
        - the name of the final-state tensor. Callers pass this as ``s_name``
          to the import helpers.
    """
    with tf.Graph().as_default():
        # Graph-level seed so the randomly initialised weights that get frozen
        # are reproducible across runs.
        tf.set_random_seed(1)
        x = tf.placeholder(tf.float32, shape=shape, name="x")
        _, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            # The output op name tells the converter which result node the
            # frozen graph must be able to compute.
            frozen = tf_graph_util.convert_variables_to_constants(
                session, session.graph_def, [s.op.name])
            return frozen, s.name


def v2ba(graph_def, s_name, data):
    """Import *graph_def* unchanged and print the tensor *s_name* for *data*.

    The graph's own ``x:0`` input is fed directly, with no ``input_map``.
    The printed line is tagged ``'2ba'``.

    Args:
        graph_def: A GraphDef containing an input tensor ``x:0``.
        s_name: Name of the tensor to evaluate.
        data: Value fed to the imported ``x:0`` tensor.
    """
    wanted = ["x:0", s_name]
    graph = tf.Graph()
    with graph.as_default():
        imported_input, imported_state = tf.import_graph_def(
            graph_def, return_elements=wanted)

        with tf.Session() as sess:
            state_value = sess.run(imported_state,
                                   feed_dict={imported_input: data})
            print('2ba', state_value)


def v2bb(graph_def, s_name, data):
    """Import *graph_def* with ``x:0`` rewired to a fresh (2, 3, 5) placeholder.

    Prints the tag ``'2bb'`` followed by the value of tensor *s_name* when the
    new placeholder is fed *data*.

    Args:
        graph_def: A GraphDef containing an input tensor ``x:0``.
        s_name: Name of the tensor to evaluate.
        data: Value fed to the replacement placeholder.
    """
    graph = tf.Graph()
    with graph.as_default():
        feed_point = tf.placeholder(tf.float32, shape=(2, 3, 5))
        mapping = {"x:0": feed_point}
        # Exactly one element is requested, so unpack it strictly.
        (state,) = tf.import_graph_def(
            graph_def,
            input_map=mapping,
            return_elements=[s_name],
        )

        with tf.Session() as sess:
            result = sess.run(state, feed_dict={feed_point: data})
            print('2bb', result)


def v2bc(graph_def, s_name, data):
    """Import *graph_def* with ``x:0`` rewired to a (None, None, 5) placeholder.

    Same as the fixed-shape remapping variant, except the replacement
    placeholder leaves its first two dimensions unknown. Prints the tag
    ``'2bc'`` followed by the value of tensor *s_name* for *data*.

    Args:
        graph_def: A GraphDef containing an input tensor ``x:0``.
        s_name: Name of the tensor to evaluate.
        data: Value fed to the replacement placeholder.
    """
    graph = tf.Graph()
    with graph.as_default():
        dynamic_input = tf.placeholder(tf.float32, shape=(None, None, 5))
        (state,) = tf.import_graph_def(
            graph_def,
            input_map={"x:0": dynamic_input},
            return_elements=[s_name],
        )

        with tf.Session() as sess:
            feeds = {dynamic_input: data}
            print('2bc', sess.run(state, feed_dict=feeds))


def main():
    """Run every experiment: the in-memory RNN, then three re-imports of its frozen graph.

    ``data1`` has shape (2, 3, 5); ``data2`` has shape (1, 3, 5), i.e. a
    different batch size.
    """
    data1 = numpy.random.random_sample((2, 3, 5))
    data2 = numpy.random.random_sample((1, 3, 5))
    v1(data1)
    # v2a returns a (frozen GraphDef, state tensor name) pair. The importers
    # need both, so unpack it instead of passing the tuple as graph_def.
    graph_def, s_name = v2a()
    v2ba(graph_def, s_name, data1)
    v2bb(graph_def, s_name, data1)
    v2bc(graph_def, s_name, data1)
    # NOTE(review): v2a() is called with its default (2, 3, 5) input shape, so
    # the batch size of 2 may be baked into the frozen graph. This batch-1 run
    # may therefore fail even though the remapped placeholder is dynamic.
    # Confirm.
    v2bc(graph_def, s_name, data2)


if __name__ == "__main__":
    main()

猜你喜欢

转载自blog.csdn.net/weixin_38145317/article/details/79408409