网上看到的一个程序，运行了一下，有些地方还有错误。先存下来，其中有些写法可以借鉴。
"""Experiments with freezing a TF 1.x ``dynamic_rnn`` graph and re-importing it.

``v1`` runs the RNN directly. ``v2a`` builds the same RNN and freezes it
(variables converted to constants). The ``v2b*`` functions re-import the
frozen ``GraphDef`` in different ways and feed data through it.

Requires TensorFlow 1.x (uses ``tf.contrib``, ``tf.Session`` and
``tf.placeholder``).
"""
import numpy
import tensorflow as tf
from tensorflow import graph_util as tf_graph_util
from tensorflow.contrib import rnn as tfc_rnn


def v1(data):
    """Run a 7-unit BasicRNNCell over ``data`` and print its final state.

    The input placeholder accepts any batch size and sequence length; the
    last dimension must be 5.
    """
    with tf.Graph().as_default():
        tf.set_random_seed(1)
        x = tf.placeholder(tf.float32, shape=(None, None, 5))
        _, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            print(session.run(s, feed_dict={x: data}))


def v2a(shape=(2, 3, 5)):
    """Build the RNN, freeze its variables and return the frozen graph.

    Args:
        shape: static shape of the input placeholder, which is named ``"x"``.
            The default ``(2, 3, 5)`` reproduces the original fixed-shape
            graph. Pass ``(None, None, 5)`` to freeze a graph that can later
            be fed any batch size / sequence length.

    Returns:
        A ``(graph_def, s_name)`` pair: the frozen ``GraphDef`` and the name
        of the RNN final-state tensor inside it.
    """
    with tf.Graph().as_default():
        tf.set_random_seed(1)
        x = tf.placeholder(tf.float32, shape=shape, name="x")
        _, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            # Replace the variables feeding the final-state op with constants
            # holding their current (just-initialised) values.
            graph_def = tf_graph_util.convert_variables_to_constants(
                session, session.graph_def, [s.op.name])
    return graph_def, s.name


def v2ba(graph_def, s_name, data):
    """Import ``graph_def`` as-is and feed ``data`` straight into its ``x:0``."""
    with tf.Graph().as_default():
        x, s = tf.import_graph_def(graph_def, return_elements=["x:0", s_name])
        with tf.Session() as session:
            print('2ba', session.run(s, feed_dict={x: data}))


def v2bb(graph_def, s_name, data):
    """Import ``graph_def``, remapping ``x:0`` to a new (2, 3, 5) placeholder."""
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=(2, 3, 5))
        [s] = tf.import_graph_def(graph_def, input_map={"x:0": x},
                                  return_elements=[s_name])
        with tf.Session() as session:
            print('2bb', session.run(s, feed_dict={x: data}))


def v2bc(graph_def, s_name, data):
    """Import ``graph_def``, remapping ``x:0`` to a (None, None, 5) placeholder.

    Remapping the input only relaxes the placeholder's shape. Shape-dependent
    parts already inside ``graph_def`` are unchanged, so ``data`` must still be
    compatible with the shape the graph was frozen with.
    """
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=(None, None, 5))
        [s] = tf.import_graph_def(graph_def, input_map={"x:0": x},
                                  return_elements=[s_name])
        with tf.Session() as session:
            print('2bc', session.run(s, feed_dict={x: data}))


def main():
    """Run every experiment on a batch-of-2 and a batch-of-1 input."""
    data1 = numpy.random.random_sample((2, 3, 5))
    data2 = numpy.random.random_sample((1, 3, 5))
    v1(data1)

    # v2a() returns a (graph_def, s_name) pair; every v2b* function needs both.
    graph_def, s_name = v2a()
    v2ba(graph_def, s_name, data1)
    v2bb(graph_def, s_name, data1)
    v2bc(graph_def, s_name, data1)

    # NOTE(review): in TF 1.x, dynamic_rnn uses the static batch dimension,
    # when known, to build its zero initial state. A graph frozen with input
    # shape (2, 3, 5) therefore keeps a batch-2 state and rejects batch-1
    # input, even through v2bc's (None, None, 5) placeholder. Freeze a
    # shape-agnostic graph for the batch-1 case.
    flexible_graph_def, flexible_s_name = v2a(shape=(None, None, 5))
    v2bc(flexible_graph_def, flexible_s_name, data2)


if __name__ == "__main__":
    main()