Original link: https://www.cnblogs.com/hellcat/p/6925757.html
1. TensorFlow general model loading method:
The checkpoint file records save information; through it, the most recently saved model can be located:
```python
ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)
```
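For reference, the `checkpoint` file that `tf.train.get_checkpoint_state` reads is a small text file in the model directory; its contents typically look like the following (the step numbers here are illustrative):

```
model_checkpoint_path: "model.ckpt-1000"
all_model_checkpoint_paths: "model.ckpt-500"
all_model_checkpoint_paths: "model.ckpt-1000"
```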
- The `.meta` file saves the current graph structure
- The `.index` file saves the current parameter names
- The `.data` file saves the current parameter values
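Concretely, a checkpoint saved under a given prefix produces exactly these files on disk. A minimal sketch that builds the expected file names (the prefix and step number here are hypothetical; a real prefix comes from `saver.save()`):

```python
# Hypothetical checkpoint prefix, e.g. as written by saver.save(sess, 'model.ckpt', global_step=1000)
prefix = 'model.ckpt-1000'

# The three files described above, as TensorFlow writes them to disk
# (.data is sharded, hence the 00000-of-00001 suffix for a single shard)
checkpoint_files = [prefix + '.meta',
                    prefix + '.index',
                    prefix + '.data-00000-of-00001']
print(checkpoint_files)
```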
The tf.train.import_meta_graph function loads the graph structure from the given model.ckpt-n.meta path and returns a saver object.
The tf.train.Saver constructor returns a saver object bound to the default graph.
Given the path model.ckpt-n, saver.restore automatically finds the parameter name and value files and loads them.
1. Load only the parameters, not the graph structure
What a checkpoint actually saves are the values of Variable objects. Other values (such as batch_size) we may want to change when restoring, while the graph structure itself is usually fixed during training. We can therefore create a new default graph with tf.Graph().as_default() (recommended as a context manager) and, in this new graph, modify the parameters that are unrelated to the variables, achieving the desired result.
```python
'''
Load a model saved from the original network into a redefined graph.
Models can be loaded using Python variable names or node names.
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf

IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'

with tf.Graph().as_default() as g:
    x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
    y = Net.inference_1(x, N_CLASS=5, train=False)

    with tf.Session() as sess:
        # A Variable must already exist at this point, otherwise save or restore
        # raises an error complaining that there are no variables to save
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state('./model/')
        img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
        img = sess.run(tf.expand_dims(tf.image.resize_images(
            tf.image.decode_jpeg(img_raw), [224, 224],
            method=random.randint(0, 3)), 0))

        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, './model/model.ckpt-0')
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            res = sess.run(y, feed_dict={x: img})
            print(global_step, sess.run(tf.argmax(res, 1)))
```
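The global_step extraction in the snippet above relies only on the checkpoint path's naming pattern, so it can be sketched standalone (the path here is hypothetical, mirroring what saver.save writes):

```python
# Hypothetical checkpoint path of the form written by saver.save(..., global_step=n)
model_checkpoint_path = './model/model.ckpt-1000'

# Take the file name, then the part after the last '-'
global_step = model_checkpoint_path.split('/')[-1].split('-')[-1]
print(global_step)  # -> '1000' (note: a string, not an int)
```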
2. Load graph structure and parameters
Note that in both approaches, a node's output tensor can be obtained via the node name, and the node.name attribute returns the node name.
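TensorFlow's convention is that a node's k-th output tensor is named `'<node>:<k>'`, so `'conv1:0'` denotes the first output tensor of node `'conv1'`. This can be sketched with plain string handling (the helper names here are illustrative, not TensorFlow API):

```python
def tensor_name(node_name, output_index=0):
    # TensorFlow names the k-th output tensor of a node '<node>:<k>'
    return '{}:{}'.format(node_name, output_index)

def node_of(tensor_name_str):
    # Strip the output index to recover the node name
    return tensor_name_str.rsplit(':', 1)[0]

print(tensor_name('conv1'))  # -> conv1:0
print(node_of('conv1:0'))    # -> conv1
```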
A simplified sketch of both approaches:
```python
# Load the graph structure together with the parameters
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)

# Load only the data, not the graph structure; values such as
# batch_size can then be changed in the new graph.
# Note that the new graph structure must be defined before the Saver
# object is instantiated, otherwise an error is raised.
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')
    saver.restore(sess, ckpt.model_checkpoint_path)
```
2. TensorFlow binary model loading method:
This loading method is typically used to adapt pre-trained network models that major companies have published online.
```python
# Create a new, empty graph
self.graph = tf.Graph()
# Use the empty graph as the default graph
with self.graph.as_default():
    # Read the model file in binary mode
    with tf.gfile.FastGFile(os.path.join(model_dir, model_name), 'rb') as f:
        # Create a new GraphDef to temporarily hold the graph from the model file
        graph_def = tf.GraphDef()
        # Parse the serialized graph from the model file into the GraphDef
        graph_def.ParseFromString(f.read())
        # Import the graph from the GraphDef into the empty graph
        tf.import_graph_def(graph_def, name='')
    # To get a tensor from the graph, use graph.get_tensor_by_name with the tensor name.
    # The tensors obtained here can be passed directly to a session's run method.
    # A bit of background: if 'conv1' is a node name, then 'conv1:0' is the
    # tensor name of that node's first output tensor.
    self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
    self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0')
                          for name in self.layer_operation_names]
```