TensorFlow checkpoint loading [reposted]

Original link: https://www.cnblogs.com/hellcat/p/6925757.html

1. TensorFlow general model loading method:

The checkpoint file records the save information, and through it the most recently saved model can be located:

ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)
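Note that the CheckpointState returned by tf.train.get_checkpoint_state also lists every checkpoint still kept on disk, not only the latest one. A minimal sketch, again assuming the checkpoints live under './model/':

import tensorflow as tf

# Locate the checkpoint metadata written by tf.train.Saver (assumes './model/' exists)
ckpt = tf.train.get_checkpoint_state('./model/')
if ckpt:
    # Most recently saved model, e.g. './model/model.ckpt-0'
    print(ckpt.model_checkpoint_path)
    # All retained checkpoints (how many are kept is controlled by the Saver's max_to_keep)
    for path in ckpt.all_model_checkpoint_paths:
        print(path)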

 

The .meta file saves the current graph structure.

The .index file holds the current parameter names.

The .data file holds the current parameter values.

Given the path to model.ckpt-n.meta, the tf.train.import_meta_graph function loads the graph structure and returns a Saver object.

The tf.train.Saver function returns a Saver object bound to the default graph.

Given the path model.ckpt-n, the saver.restore function automatically finds the matching parameter name and value files (.index and .data) and loads them.
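To see which parameter names and values a checkpoint actually stores (i.e. the contents of the .index and .data files), TensorFlow 1.x also ships a checkpoint reader. A minimal sketch, assuming a checkpoint prefix such as './model/model.ckpt-0' already exists on disk:

import tensorflow as tf

# Open the checkpoint by its prefix (no .index/.data suffix)
reader = tf.train.NewCheckpointReader('./model/model.ckpt-0')

# Parameter names and shapes, as recorded in the .index file
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)
    # The corresponding values, stored in the .data file
    value = reader.get_tensor(name)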

 

1. Load only the parameters, not the graph structure

What a checkpoint actually saves are the values of Variable objects; other settings (such as batch_size) may need to be changed when restoring, while the graph structure itself is usually fixed during training. We can therefore create a new default graph with tf.Graph().as_default() (using it as a context manager is recommended), redefine in this new graph any sizes or settings that are not tied to the saved variables, and then restore the variable values into it.

 
'''
Load the model saved using the original network to the redefined graph
Models can be loaded using python variable names or node names
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf
 
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
 
with tf.Graph().as_default() as g:
 
    x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
    y = Net.inference_1(x, N_CLASS=5, train=False)
 
    with tf.Session() as sess:
        # The graph must already contain Variables before the Saver is created,
        # otherwise TensorFlow complains that there are no variables to save/restore
        saver = tf.train.Saver()
 
        ckpt = tf.train.get_checkpoint_state('./model/')
        img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
        img = sess.run(tf.expand_dims(tf.image.resize_images(
            tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))
 
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess,'./model/model.ckpt-0')
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            res = sess.run(y, feed_dict={x: img})
            print(global_step,sess.run(tf.argmax(res,1)))

 

2. Load the graph structure and parameters


Note that with either loading method a node's output tensor can be fetched via the node name, and the node's .name attribute returns that name (see the sketch after the simplified code below).

A simplified illustration:

 
# Load with graph structure
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
    saver.restore(sess,ckpt.model_checkpoint_path)
             
# Load only the data, not the graph structure; values such as batch_size can be changed in the new graph
# Note that the new graph structure must be defined before the Saver object is instantiated, otherwise an error is raised
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')
    saver.restore(sess,ckpt.model_checkpoint_path)
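When only the meta graph is imported, there are no Python variables pointing at the placeholders or outputs, so tensors have to be fetched from the restored graph by name, as mentioned above. A minimal sketch, using hypothetical node names 'input_x' and 'logits' (replace them with the names actually assigned when the original network was built):

import tensorflow as tf

ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')

with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    graph = tf.get_default_graph()
    # 'input_x' and 'logits' are assumed node names; ':0' selects the node's first output tensor
    x = graph.get_tensor_by_name('input_x:0')
    y = graph.get_tensor_by_name('logits:0')
    # result = sess.run(y, feed_dict={x: image_batch})  # image_batch: an array matching x's shape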

2. TensorFlow binary model loading method:

This loading method is generally used when reusing or modifying network models that major companies have already trained and published on the Internet; such a model is distributed as a single binary file (a frozen GraphDef, .pb):


# Create a new blank graph
self.graph = tf.Graph()
# Use the blank graph as the default graph
with self.graph.as_default():
    # Read the model file in binary mode
    with tf.gfile.FastGFile(os.path.join(model_dir, model_name), 'rb') as f:
        # Create a new GraphDef to temporarily hold the graph stored in the model file
        graph_def = tf.GraphDef()
        # Parse the serialized graph into the GraphDef
        graph_def.ParseFromString(f.read())
        # Import the graph from the GraphDef into the blank graph
        tf.import_graph_def(graph_def, name='')
        # To get a tensor from the graph, use graph.get_tensor_by_name with the tensor name
        # The tensors obtained here can be passed directly to the session's run method
        # Basic point: 'conv1' is a node name, while 'conv1:0' is a tensor name, i.e. the first output tensor of that node
        self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
        self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]
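Once the GraphDef has been imported, inference is just an ordinary run call on the tensors fetched by name. A minimal standalone sketch, with a hypothetical frozen-model path and hypothetical tensor names 'input:0' and 'softmax:0' (inspect your own model for the real ones):

import tensorflow as tf

MODEL_PATH = './model/frozen_graph.pb'  # hypothetical path to a frozen GraphDef

graph = tf.Graph()
with graph.as_default():
    with tf.gfile.FastGFile(MODEL_PATH, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    # Print all node names to discover what the model actually exposes
    for op in graph.get_operations():
        print(op.name)
    input_tensor = graph.get_tensor_by_name('input:0')      # assumed input tensor name
    output_tensor = graph.get_tensor_by_name('softmax:0')   # assumed output tensor name
    # result = sess.run(output_tensor, feed_dict={input_tensor: some_input})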
