Original link: https://www.cnblogs.com/hellcat/p/6925757.html
1. General TensorFlow model loading (checkpoints):
The checkpoint file in the model directory records the save history. It can be used to locate the most recently saved model:
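```python
import tensorflow as tf

# Read the 'checkpoint' file in ./model/ and print the path of the latest saved model
ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)
```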
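To show what tf.train.get_checkpoint_state relies on, the sketch below reads the checkpoint file by hand. It assumes the TensorFlow 1.x text layout: one model_checkpoint_path line for the latest save, plus one all_model_checkpoint_paths line per kept save. It ignores any other fields and the quoting rules. read_checkpoint_state is a hypothetical helper and is not part of TensorFlow.

```python
import os
import re
import tempfile

# One field per line, e.g.  model_checkpoint_path: "model.ckpt-2000"
_FIELD = re.compile(r'^(model_checkpoint_path|all_model_checkpoint_paths):\s*"(.*)"$')


def read_checkpoint_state(checkpoint_dir):
    """Hypothetical, simplified stand-in for tf.train.get_checkpoint_state.

    Returns (latest_path, all_paths), or None if there is no 'checkpoint' file.
    Relative paths are joined onto checkpoint_dir; string escaping is ignored.
    """
    state_file = os.path.join(checkpoint_dir, 'checkpoint')
    if not os.path.isfile(state_file):
        return None
    latest, all_paths = None, []
    with open(state_file) as f:
        for line in f:
            match = _FIELD.match(line.strip())
            if match is None:
                continue  # skip fields this sketch does not handle
            field, value = match.groups()
            if not os.path.isabs(value):
                value = os.path.join(checkpoint_dir, value)
            if field == 'model_checkpoint_path':
                latest = value
            else:
                all_paths.append(value)
    return latest, all_paths


# Write a small checkpoint file in the TF 1.x layout, then read it back.
with tempfile.TemporaryDirectory() as model_dir:
    with open(os.path.join(model_dir, 'checkpoint'), 'w') as f:
        f.write('model_checkpoint_path: "model.ckpt-2000"\n'
                'all_model_checkpoint_paths: "model.ckpt-1000"\n'
                'all_model_checkpoint_paths: "model.ckpt-2000"\n')
    latest, _ = read_checkpoint_state(model_dir)
    print(os.path.basename(latest))  # model.ckpt-2000
```

Relative paths are joined onto the directory, so the printed path carries the './model/' prefix, as in the original print.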
The .meta file stores the graph structure.
The .index file stores the variable (parameter) names, as an index into the data files.
The .data file (e.g. model.ckpt-n.data-00000-of-00001) stores the variable (parameter) values.
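To see how these three files relate, the hypothetical helper below groups a directory listing by checkpoint prefix. It assumes the TensorFlow 1.x naming scheme: saving to model.ckpt with global_step=n produces model.ckpt-n.meta, model.ckpt-n.index and one or more model.ckpt-n.data-XXXXX-of-YYYYY shards.

```python
import re
from collections import defaultdict

# <prefix>.meta, <prefix>.index, <prefix>.data-00000-of-00001 (one or more data shards)
_CKPT_FILE = re.compile(r'^(?P<prefix>.+)\.(?P<kind>meta|index|data-\d{5}-of-\d{5})$')


def group_checkpoint_files(filenames):
    """Group checkpoint file names by their prefix, e.g. 'model.ckpt-1000'."""
    groups = defaultdict(dict)
    for name in filenames:
        match = _CKPT_FILE.match(name)
        if match is None:
            continue  # skips the 'checkpoint' state file and unrelated files
        prefix, kind = match.group('prefix'), match.group('kind')
        if kind.startswith('data-'):
            groups[prefix].setdefault('data', []).append(name)
        else:
            groups[prefix][kind] = name
    return dict(groups)


files = ['checkpoint',
         'model.ckpt-1000.meta', 'model.ckpt-1000.index',
         'model.ckpt-1000.data-00000-of-00001']
print(group_checkpoint_files(files))
```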
Given the path of model.ckpt-n.meta, tf.train.import_meta_graph loads the graph structure and returns a Saver object.
tf.train.Saver() returns a Saver object for the variables in the current default graph.
Given the checkpoint prefix model.ckpt-n (with no file suffix), saver.restore automatically locates the matching parameter name and value files and loads the values.
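The two calls expect different paths, so it is easy to pass the wrong one: the .meta file goes to import_meta_graph, and the bare prefix goes to restore. The sketch below uses hypothetical helpers, checkpoint_prefix and meta_graph_path, to normalize any of the file names to the right form. It assumes the TF 1.x suffixes listed earlier. In TF 1.x code they would be used as tf.train.import_meta_graph(meta_graph_path(p)) and saver.restore(sess, checkpoint_prefix(p)).

```python
import re

# Suffixes that TensorFlow 1.x appends to a checkpoint prefix
_SUFFIX = re.compile(r'\.(meta|index|data-\d{5}-of-\d{5})$')


def checkpoint_prefix(path):
    """Strip a file suffix to get the prefix that saver.restore expects."""
    return _SUFFIX.sub('', path)


def meta_graph_path(path):
    """Return the graph file path that tf.train.import_meta_graph expects."""
    return checkpoint_prefix(path) + '.meta'


for p in ['./model/model.ckpt-1000',
          './model/model.ckpt-1000.index',
          './model/model.ckpt-1000.data-00000-of-00001']:
    print(meta_graph_path(p), checkpoint_prefix(p))
```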
1. Load only the parameters, not the graph structure
Only the values of Variables are saved as parameters. Other settings that are not variables (such as batch_size) may need to change at restore time, whereas the graph structure is generally fixed during training. We can therefore create a new default graph with tf.Graph().as_default() (preferably used as a with context manager). We then rebuild the network in it with the modified non-variable settings and restore the saved variable values into it.
```python
'''
Load a model saved with the original network into a redefined graph.
Tensors in the model can be referenced by Python variable names or by node names.
'''
import random

import tensorflow as tf

import AlexNet as Net
import AlexNet_train as train

IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'

with tf.Graph().as_default() as g:
    x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
    y = Net.inference_1(x, N_CLASS=5, train=False)

    with tf.Session() as sess:
        # The graph must already contain Variables when tf.train.Saver() is created;
        # otherwise it raises an error saying there are no variables to save.
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state('./model/')

        img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
        # Decode the JPEG, resize to 224x224 with a randomly chosen resize method,
        # and add a batch dimension
        img = sess.run(tf.expand_dims(tf.image.resize_images(
            tf.image.decode_jpeg(img_raw), [224, 224],
            method=random.randint(0, 3)), 0))

        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            # Restore the most recent checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            res = sess.run(y, feed_dict={x: img})
            print(global_step, sess.run(tf.argmax(res, 1)))
```
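The code above extracts global_step with split('/') and split('-'). That returns a string, and for a prefix saved without global_step it returns 'model.ckpt' rather than a number. The sketch below is slightly more defensive. It assumes the checkpoint was written with saver.save(..., global_step=n), which appends '-n' to the prefix. global_step_from_path is a hypothetical helper.

```python
import os


def global_step_from_path(model_checkpoint_path):
    """Parse n out of a checkpoint prefix such as './model/model.ckpt-n'.

    Returns an int, or None if the file name has no numeric '-n' suffix.
    """
    name = os.path.basename(model_checkpoint_path)
    _, sep, tail = name.rpartition('-')
    return int(tail) if sep and tail.isdigit() else None


print(global_step_from_path('./model/model.ckpt-1500'))  # 1500
print(global_step_from_path('./model/model.ckpt'))       # None
```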
2. Load both the graph structure and the parameters
Note that with either approach, a node's output tensor can be fetched by name (for example with graph.get_tensor_by_name('name:0')). An operation's .name property returns its node name.
A condensed version of both approaches:
```python
import tensorflow as tf

# Option 1: load with the graph structure
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)

# Option 2: load only the data, not the graph structure; values such as batch_size
# can be changed in the new graph.
# Note that the new graph structure must be defined before the Saver object is
# instantiated, otherwise an error is raised.
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')
    saver.restore(sess, ckpt.model_checkpoint_path)
```
2. Loading a binary (GraphDef) TensorFlow model:
This loading method is typically used to load and adapt pretrained network models that major companies have published online.
```python
import os

import tensorflow as tf

# Create a new blank graph
self.graph = tf.Graph()
# Make the blank graph the default graph
with self.graph.as_default():
    # Read the model file in binary mode
    with tf.gfile.FastGFile(os.path.join(model_dir, model_name), 'rb') as f:
        # Create a GraphDef to temporarily hold the graph stored in the model file
        graph_def = tf.GraphDef()
        # Parse the serialized graph into the GraphDef
        graph_def.ParseFromString(f.read())
        # Import the graph from the GraphDef into the blank graph
        tf.import_graph_def(graph_def, name='')
    # To get a tensor from the graph, call graph.get_tensor_by_name with the tensor name;
    # the returned tensors can be evaluated directly with session.run.
    # Background: a name like 'conv1' is a node name, while 'conv1:0' is a tensor name,
    # meaning the first output tensor of that node.
    self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
    self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0')
                          for name in self.layer_operation_names]
```
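The node-name versus tensor-name convention can be expressed as two small helpers. split_tensor_name and tensor_name are hypothetical and not TensorFlow functions. This sketch raises on a bare node name, mirroring the fact that get_tensor_by_name expects the '<op_name>:<index>' form. Because the code above passes name='' to tf.import_graph_def, imported names stay unprefixed. With the default name ('import'), they would look like 'import/conv1:0'.

```python
def split_tensor_name(name):
    """Split 'conv1:0' into ('conv1', 0): the node (operation) name and output index."""
    op_name, sep, index = name.rpartition(':')
    if not sep or not index.isdigit():
        raise ValueError(f"{name!r} is not a tensor name of the form '<op_name>:<index>'")
    return op_name, int(index)


def tensor_name(op_name, output_index=0):
    """Build a tensor name from a node name, like name + ':0' in the code above."""
    return f'{op_name}:{output_index}'


print(split_tensor_name('import/mixed3a:1'))  # ('import/mixed3a', 1)
print(tensor_name('conv1'))                   # conv1:0
```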