TensorFlow checkpoint loading [transfer]

Original link: https://www.cnblogs.com/hellcat/p/6925757.html

1. The general TensorFlow model loading method:

The checkpoint file records the save state; through it, the most recently saved model can be located:

ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)
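
The checkpoint file that get_checkpoint_state reads is just a small text file listing the saved model paths. As a rough illustration of what it contains (the real parsing is done by TensorFlow; the field names below are the ones TensorFlow writes, but this stand-in function is only a sketch), here is a pure-Python version that extracts the latest path:

```python
import os
import tempfile

def latest_checkpoint_path(checkpoint_dir):
    """Minimal stand-in for tf.train.get_checkpoint_state: read the text
    'checkpoint' file and return its model_checkpoint_path field."""
    state_file = os.path.join(checkpoint_dir, 'checkpoint')
    with open(state_file) as f:
        for line in f:
            key, _, value = line.partition(':')
            if key.strip() == 'model_checkpoint_path':
                return value.strip().strip('"')
    return None

# Build a fake checkpoint state file to demonstrate the format.
d = tempfile.mkdtemp()
with open(os.path.join(d, 'checkpoint'), 'w') as f:
    f.write('model_checkpoint_path: "model.ckpt-1000"\n')
    f.write('all_model_checkpoint_paths: "model.ckpt-500"\n')
    f.write('all_model_checkpoint_paths: "model.ckpt-1000"\n')

print(latest_checkpoint_path(d))  # model.ckpt-1000
```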

 

The .meta file saves the current graph structure.

The .index file saves the current parameter names.

The .data file saves the current parameter values.

Given the path to model.ckpt-n.meta, tf.train.import_meta_graph loads the graph structure and returns a Saver object.

tf.train.Saver() returns a Saver object bound to the default graph.

Given the path to model.ckpt-n, saver.restore automatically finds the parameter name and value files and loads them.

 

1. Load only the parameters, not the graph structure

What is actually saved are the values of Variables; other settings (such as batch_size) are things we may want to change when restoring. Since the graph structure is fixed during training, we can create a new default graph with tf.Graph().as_default() (the context-manager form is recommended), redefine the graph there with whatever non-Variable settings we need changed, and then restore only the parameter values into it.

 
'''
Load the model saved using the original network to the redefined graph
Models can be loaded using python variable names or node names
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf
 
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
 
with tf.Graph().as_default() as g:
 
    x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
    y = Net.inference_1(x, N_CLASS=5, train=False)
 
    with tf.Session() as sess:
        # A Variable must exist in the graph before creating the Saver,
        # otherwise save/restore reports that there are no variables to save
        saver = tf.train.Saver()
 
        ckpt = tf.train.get_checkpoint_state('./model/')
        img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
        img = sess.run(tf.expand_dims(tf.image.resize_images(
            tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))
 
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            res = sess.run(y, feed_dict={x: img})
            print(global_step,sess.run(tf.argmax(res,1)))

 

2. Load both the graph structure and the parameters


Note that with either approach, a node's output tensor can be retrieved by the node name, and the node.name property returns that name.
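
The node-name versus tensor-name distinction can be demonstrated in a few lines (a minimal sketch, assuming TensorFlow is installed; tf.Graph and get_tensor_by_name work this way in both 1.x and 2.x graph mode):

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # The op (node) is named 'conv1'; its first output tensor is 'conv1:0'.
    t = tf.constant([1.0, 2.0], name='conv1')

print(t.name)                                # conv1:0
print(g.get_tensor_by_name('conv1:0') is t)  # True
```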

A simplified comparison of the two approaches:

 
# Load with graph structure
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
    saver.restore(sess,ckpt.model_checkpoint_path)
             
# Load only the data, not the graph structure; values such as batch_size can be changed in the new graph
# Note that the new graph structure must be defined before the Saver object is instantiated, otherwise an error is reported
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')
    saver.restore(sess,ckpt.model_checkpoint_path)

2. TensorFlow binary model loading method:

This loading method is typically used to adapt pretrained network models that major companies have released online as binary .pb files.


# Create a new blank graph
self.graph = tf.Graph()
# Use the blank graph as the default graph
with self.graph.as_default():
    # Read the model file in binary mode
    with tf.gfile.FastGFile(os.path.join(model_dir, model_name), 'rb') as f:
        # Create a new GraphDef to temporarily hold the graph stored in the model
        graph_def = tf.GraphDef()
        # Parse the model's graph into the GraphDef
        graph_def.ParseFromString(f.read())
        # Import the graph from the GraphDef into the blank graph
        tf.import_graph_def(graph_def, name='')
        # Tensors are fetched from the graph with graph.get_tensor_by_name plus the tensor name
        # These tensors can be passed directly to session.run for evaluation
        # Background note: 'conv1' is a node name, while 'conv1:0' is a tensor name,
        # referring to the node's first output tensor
        self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
        self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0')
                              for name in self.layer_operation_names]
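
The serialize/parse/import round trip above can be exercised end to end without a downloaded .pb file (a minimal sketch, assuming TensorFlow is installed; the graph_pb2.GraphDef class used here is what tf.GraphDef refers to in 1.x, and the node names 'input'/'output' are made up for the demo):

```python
import tensorflow as tf
from tensorflow.core.framework import graph_pb2

# Build a small graph and serialize it, standing in for a downloaded .pb file.
g1 = tf.Graph()
with g1.as_default():
    x = tf.constant(3.0, name='input')
    y = tf.multiply(x, 2.0, name='output')
pb_bytes = g1.as_graph_def().SerializeToString()

# Load the serialized bytes into a fresh blank graph, as in the snippet above.
g2 = tf.Graph()
with g2.as_default():
    graph_def = graph_pb2.GraphDef()  # tf.GraphDef in TF 1.x
    graph_def.ParseFromString(pb_bytes)
    tf.import_graph_def(graph_def, name='')

# The imported tensors are addressed by tensor name ('node_name:0').
out = g2.get_tensor_by_name('output:0')
print(out.op.name)  # output
```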
