TensorFlow model saving, restoring, and fine-tuning (staged save and load)

  • Saving a model (*.index holds the variable names, *.meta the model graph, *.data-* the variable values)
# Save a model with tf.train.Saver: produces *.index, *.meta and *.data-* files.
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
# The with-block closes the session automatically (replaces sess.close()).
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print(sess.run([weights]))
    # MODEL_DIR / MODEL_NAME are assumed to be defined earlier in the file — TODO confirm.
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
  • Saving a model repeatedly (when the same model is saved multiple times, skip re-writing the meta graph to save time)
# Save the same model several times: the graph does not change between saves,
# so write the meta graph only on the first save (write_meta_graph=False later).
import time  # used by the sleeps below; harmless if already imported at file top

tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run([weights]))
saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
time.sleep(5)
# BUG FIX: the original used "%s/%s1", appending a stray '1' to the file name;
# MODEL1_NAME / MODEL2_NAME already distinguish the checkpoints.
saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL1_NAME), write_meta_graph=False)
time.sleep(5)
saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL2_NAME), write_meta_graph=False)

sess.close()
  • Restoring a model (rebuild the network in code; the *.meta file is not needed)
# Restore a model by rebuilding the identical graph in code;
# with this approach the saved *.meta file is not used.
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
with tf.Session() as sess:
    # restore() loads the saved values; no initializer run is needed.
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

    print(sess.run([weights]))
  • Restoring a model (rebuilding the network from the *.meta file)
# Restore a model whose graph is rebuilt from the saved *.meta file,
# so the network does not have to be redefined in code.
tf.reset_default_graph()

meta_path = "%s/%s.meta" % (MODEL_DIR, MODEL_NAME)
saver = tf.train.import_meta_graph(meta_path)
with tf.Session() as sess:
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

    # No Python handle exists for the variable, so look it up by tensor name.
    weights_tensor = tf.get_default_graph().get_tensor_by_name("weights:0")
    print(sess.run([weights_tensor]))
  • Restoring a model (when a model is saved several times into one folder, the checkpoint file records every saved model name and the most recent one)
# Restore the most recent checkpoint from a directory: the 'checkpoint' file
# there records all saved model names plus the latest one.
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
ckpt = tf.train.get_checkpoint_state(MODEL_DIR)
# ROBUSTNESS: get_checkpoint_state returns None when the directory has no
# 'checkpoint' file; the original then crashed with an opaque AttributeError.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError("no checkpoint found in %s" % MODEL_DIR)
saver.restore(sess, ckpt.model_checkpoint_path)

print(sess.run([weights]))

sess.close()
  • Fine-tuning a model (restore some previously trained parameters, add new ones, and continue training)
def get_variables_available_in_checkpoint(variables, checkpoint_path, include_global_step=True):
    """Return the subset of `variables` stored in the checkpoint with a matching shape.

    Args:
        variables: dict mapping variable name -> tf.Variable.
        checkpoint_path: path of the checkpoint to inspect.
        include_global_step: if False, ignore the global-step entry.

    Returns:
        dict of name -> variable for every name found in the checkpoint whose
        stored shape equals the live variable's shape, in sorted-name order.
    """
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    shape_map = reader.get_variable_to_shape_map()
    if not include_global_step:
        shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
    # .get() returns None for names absent from the checkpoint, which can
    # never equal a shape list, so this covers both membership and shape.
    return {
        name: var
        for name, var in sorted(variables.items())
        if shape_map.get(name) == var.shape.as_list()
    }

# Fine-tuning: restore only the variables that exist in the checkpoint
# (weights/biases) and let the new ones (other_weights) keep their own
# initializer values.
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
other_weights = tf.Variable(tf.zeros([10, 10]))

checkpoint_path = "%s/%s" % (MODEL_DIR, MODEL_NAME)
var_dict = {var.op.name: var for var in tf.global_variables()}
available_var_map = get_variables_available_in_checkpoint(
    var_dict, checkpoint_path, include_global_step=False)
# Rewire the matched variables' initializers to read from the checkpoint;
# global_variables_initializer() then loads them instead of random values.
tf.train.init_from_checkpoint(checkpoint_path, available_var_map)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print(sess.run([weights]))
  • Save the model (binary model)
# Export a frozen binary (*.pb) model: fold the variables' current values
# into constants so the graph ships as a single self-contained file.
from tensorflow.python.framework.graph_util import convert_variables_to_constants

tf.reset_default_graph()

saver = tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
with tf.Session() as sess:
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

    # Keep only the subgraph needed to compute the listed output nodes.
    frozen_graph = convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=['weights'])
    with tf.gfile.GFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), "wb") as output:
        output.write(frozen_graph.SerializeToString())
  • Load Model (binary model)
# Load a frozen binary (*.pb) model and evaluate a tensor from it.
tf.reset_default_graph()

with tf.Session() as sess:
    with tf.gfile.FastGFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
    # A frozen graph contains constants only, so this initializer has nothing
    # to initialize; kept to mirror the original behavior exactly.
    sess.run(tf.global_variables_initializer())

    weights_tensor = tf.get_default_graph().get_tensor_by_name("weights:0")
    print(sess.run([weights_tensor]))

 

 

 

references:

https://blog.csdn.net/loveliuzz/article/details/81661875

https://www.cnblogs.com/bbird/p/9951943.html

https://blog.csdn.net/gzj_1101/article/details/80299610

 

Guess you like

Origin www.cnblogs.com/jhc888007/p/11620821.html