save-restore: Saving and Restoring

Saver example code:

import tensorflow as tf

# Define the variables to be saved.
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')

init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    # Write the variable values to a checkpoint file.
    save_path = saver.save(sess, "/home/violet/aitest/ResNet/logs/test/save_net.ckpt")
    print("Save to path: ", save_path)

Restore example code:

# Restore variables.
# Redefine variables with the same shape and dtype as when they were saved.
import tensorflow as tf
import numpy as np

W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

# No init step is needed; restore() assigns the saved values.

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "/home/violet/aitest/ResNet/logs/test/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))

After a single call to saver.save(), four new files appear in the folder:
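With the path used above, these are typically checkpoint, save_net.ckpt.data-00000-of-00001, save_net.ckpt.index, and save_net.ckpt.meta (the exact names can vary with the Saver settings and the TensorFlow version).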

The checkpoint file keeps a list of all the model files recorded so far; model.ckpt.meta stores the structure of the TensorFlow computation graph; and model.ckpt stores the value of each variable. The exact file names written here depend on how the Saver is configured, but the path used at restore time is determined by the "model_checkpoint_path" value in the checkpoint file. Put simply, the weights and other parameters are saved, as a dictionary, into the .ckpt.data file, while the graph and metadata are saved into the .ckpt.meta file, which can be loaded into the current default graph with tf.train.import_meta_graph.
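If you don't want to redefine the variables by hand, the graph can be rebuilt from the .meta file instead. Below is a minimal sketch, assuming the checkpoint saved above: tf.train.import_meta_graph loads the saved graph into the current default graph and returns a Saver that can then restore the values.

import tensorflow as tf

with tf.Session() as sess:
    # Rebuild the saved graph structure in the current default graph.
    saver = tf.train.import_meta_graph("/home/violet/aitest/ResNet/logs/test/save_net.ckpt.meta")
    # Restore the variable values from the checkpoint.
    saver.restore(sess, "/home/violet/aitest/ResNet/logs/test/save_net.ckpt")
    # Look up a restored variable by name ("weights:0" comes from name='weights').
    W = tf.get_default_graph().get_tensor_by_name("weights:0")
    print("weights:", sess.run(W))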

Fine-tuning from an existing model

(1) Restoring a model from a checkpoint with tf.train.Saver()

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to restore all the variables.
restorer = tf.train.Saver()

# Add ops to restore some variables.
restorer = tf.train.Saver([v1, v2])

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  restorer.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model
  ...

(2) Restoring only some of the model's parameters

# Create some variables.
v1 = slim.variable(name="v1", ...)
v2 = slim.variable(name="nested/v2", ...)
...

# Get list of variables to restore (which contains only 'v2'). These are all
# equivalent methods:
variables_to_restore = slim.get_variables_by_name("v2")
# or
variables_to_restore = slim.get_variables_by_suffix("2")
# or
variables_to_restore = slim.get_variables(scope="nested")
# or
variables_to_restore = slim.get_variables_to_restore(include=["nested"])
# or
variables_to_restore = slim.get_variables_to_restore(exclude=["v1"])

# Create the saver which will be used to restore the variables.
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
  # Restore variables from disk.
  restorer.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model

(3) Restoring parameters when the variable names in the graph differ from those in the checkpoint

When restoring variables from a checkpoint file, the Saver looks up the variable names in the checkpoint and maps them onto variables in the current graph. In the previous examples we created the Saver and passed it a list of variables; in that case the names looked up in the checkpoint file are obtained implicitly from each variable's var.op.name. This works well when the variable names in the graph and in the checkpoint file are identical. When the names differ, you must give the Saver a dictionary that maps each variable name in the checkpoint file to a variable in the graph, as in the example below:

# Assuming that 'conv1/weights' should be restored from 'vgg16/conv1/weights'
def name_in_checkpoint(var):
  return 'vgg16/' + var.op.name

# Assuming that 'conv1/weights' and 'conv1/bias' should be restored from 'conv1/params1' and 'conv1/params2'
def name_in_checkpoint(var):
  if "weights" in var.op.name:
    return var.op.name.replace("weights", "params1")
  if "bias" in var.op.name:
    return var.op.name.replace("bias", "params2")

variables_to_restore = slim.get_model_variables()
variables_to_restore = {name_in_checkpoint(var):var for var in variables_to_restore}
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
  # Restore variables from disk.
  restorer.restore(sess, "/tmp/model.ckpt")

(4) Fine-tuning the network on a different task

For example, suppose we want to take a model trained on the 1000-class ImageNet classification task and fine-tune it on the 20-class Pascal VOC classification task; we load only some of the layers, as in the example below:

image, label = MyPascalVocDataLoader(...)
images, labels = tf.train.batch([image, label], batch_size=32)

# Create the model
predictions = vgg.vgg_16(images)

train_op = slim.learning.create_train_op(...)

# Specify where the Model, trained on ImageNet, was saved.
model_path = '/path/to/pre_trained_on_imagenet.checkpoint'

# Specify where the new model will live:
log_dir = '/path/to/my_pascal_model_dir/'

# Restore only the convolutional layers:
variables_to_restore = slim.get_variables_to_restore(exclude=['fc6', 'fc7', 'fc8'])
init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)

# Start training.
slim.learning.train(train_op, log_dir, init_fn=init_fn) 

Original article:
https://www.cnblogs.com/bmsl/p/dongbin_bmsl_01.html

Reposted from blog.csdn.net/violethan7/article/details/80016862