一份快速完整的Tensorflow模型保存和恢复教程(译)

原文链接A quick complete tutorial to save and restore Tensorflow models–by ANKIT SACHAN
(英文水平有限,有翻译不当的地方请见谅)

在本教程中,我将介绍:
- tensorflow模型是什么样子的?
- 如何保存一个Tensorflow模型?
- 如何恢复一个Tensorflow模型用于预测/迁移学习?
- 如何导入预训练的模型进行微调和修改?

本教程假设你已经对训练一个神经网络有一定了解。否则请先看这篇教程Tensorflow Tutorial 2: image classifier using convolutional neural network再看本教程。

什么是Tensorflow模型?

当你训练好一个神经网络后,你会想保存好你的模型便于以后使用并且用于生产。因此,什么是Tensorflow模型?Tensorflow模型主要包含网络设计(或者网络图)和训练好的网络参数的值。所以Tensorflow模型有两个主要的文件:

a) Meta图:
Meta图是一个协议缓冲区(protocol buffer),它保存了完整的Tensorflow图;比如所有的变量、运算、集合等。这个文件的扩展名是.meta

b) Checkpoint 文件
这是一个二进制文件,它保存了权重、偏置项、梯度以及其他所有的变量的取值,扩展名为.ckpt。但是, 从0.11版本开始,Tensorflow对改文件做了点修改,checkpoint文件不再是单个.ckpt文件,而是如下两个文件:

mymodel.data-00000-of-00001
mymodel.index

其中, .data文件包含了我们的训练变量。除此之外,还有一个叫checkpoint的文件,它保留了最新的checkpoint文件的记录。

总结一下,对于0.10之后的版本,tensorflow模型包含以下文件:

model files
但对于0.11之前的版本,只包含三个文件:

inception_v1.meta
inception_v1.ckpt
checkpoin

现在我们已经知道Tensorflow模型是什么样子的,让我们继续学习如何保存模型。

保存Tensorflow模型

假如你正在训练一个用于图像分类的卷积神经网络(training a convolutional neural network for image classification)。通常你会先观察损失和准确率,一旦发现网络收敛,就可以手动停止训练过程或者直接训练固定迭代次数。当训练完成后,我们想要保存所有的变量和网络图便于以后使用。因此在Tensorflow中, 为了保存网络图和所有参数的值,我们应该创建tf.train.Saver()这个类的一个对象。

saver = tf.train.Saver()

记住Tensorflow变量只有在会话(session)中才能激活。因此,你需要在会话中调用你刚创建的对象的保存方法。

saver.save(sess, "my-test-model")

这里,sess是一个session对象,“my-test-model”是你的模型名字。让我们看一个完整的例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

# This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint

如果我们要在1000次迭代后保存模型,我们应该在调用保存方法时传入步数计数:

saver.save(sess, "my_test_model", global_step=1000)

这会在模型名称后加一个“-1000”并且会创建如下文件:

my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint

假设在训练过程中,我们要每1000次迭代保存我们的模型,因此.meta文件会在第一次(1000次迭代)时创建,我们并不需要之后每1000次迭代都保存一遍这个文件(我们在2000,3000…迭代时都不需要保存这个文件,因为这个文件始终不变)。我们只需要保存这个模型供以后使用,因为模型图不会变化。所以,当我们不想重写meta图的时候,我们这样写:

saver.save(sess, "my-model", global_step=step, write_meta_graph=False)

如果你只想保留4个最新的模型并且在训练过程中每过2小时保存一次模型,你可以使用max_to_keep和keep_checkpoint_every_n_hours,就像这样:

#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

注意,如果我们在tf.train.Saver()中不指定任何东西,它将保存所有的变量。要是我们不想保存所有的变量而只是一部分变量。我们可以指定我们想要保存的变量/集合。当创建tf.train.Saver()对象的时候,我们给它传递一个我们想要保存的变量的字典列表。我们来看一个例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000)

当需要的时候,这个代码可以用来保存Tensorflow图中的指定部分。

导入预训练模型

如果你想要用其他人预训练的模型进行微调,需要做两件事:

a) 创建网络
你可以写python代码来手动创建和原来一样的模型。但是,想想看,我们已经将原始网络保存在了.meta文件中,可以用tf.train.import()函数来重建网络:

saver = tf.train.import_meta_graph("my_test_model-1000.meta")

记住,import_meta_graph函数将只将定义在.meta文件中的网络添加到当前的图上。因此,它虽然帮你创建了额图/网络,但我们还是需要导入我们在这个图上训练好的模型的参数。

b) 导入参数
我们可以调用由tf.train.Saver()创建的对象saver中的restore方法来恢复网络中的参数。

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
  new_saver.restore(sess, tf.train.latest_checkpoint('./'))

这样,张量的值(如w1和w2)就被恢复并且可以访问了:

with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('my-model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    print(sess.run('w1:0'))
##Model has been restored. Above statement will print the saved value of w1.

现在你已经理解了如何保存和导入Tensorflow模型。在下一节,我会介绍一个实际应用即导入任何预训练好的模型。

使用恢复的模型

现在你已经理解如何保存和恢复Tensorflow模型,我们来写一个实际的示例来恢复任何预训练的模型并用它来预测、微调或者进一步训练。无论你什么时候用Tensorflow,你都会定义一个网络,它有一些样本(训练数据)和超参数(如学习率、迭代次数等)。通常用一个占位符(placeholder)来将所有的训练数据和超参数输入给网络。下面我们用占位符建立一个小型网络并保存它。注意,当网络被保存的时候,占位符中的值并没有被保存。

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

当我们想要恢复这个网络的时候,我们不仅需要恢复图和权重,还需要准备一个新的feed_dict来将训练数据输入到网络中。我们可以通过graph.get_tensor_by_name方法来引用这些保存的运算和占位符变量。

#How to access saved variable/Tensor/placeholders 
w1 = graph.get_tensor_by_name("w1:0")

## How to access saved operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

如果我们只是想用不同的数据运行相同的网络,你可以方便地用feed_dict将新的数据送到网络中。

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 
#using new values of w1 and w2 and saved value of b1. 

要是你想在原来的计算图中通过添加更多的层来增加更多的运算并且训练。当然也可以实现,如下:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)

print sess.run(add_on_op,feed_dict)
#This will print 120.

但是,我们能够只恢复原来图中的一部分然后添加一些其它层来微调吗?当然可以,只要通过graph.get_tensor_by_name()方法来获取原网络的部分计算图并在上面继续建立新计算图。这里给出了一个实际的例子。我们用meta图导入了一个预训练的vgg网络,然后将最后一层的输出个数改成2用于微调新的数据。

......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 

#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')

#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()

num_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

# Now, you run this with fine-tuning data in sess.run()

希望本文能够让你清楚地理解Tensorflow是如何被保存和微调的。请在评论区自由分享你的问题或者疑问。
文中的代码亲测可用,此处再次附上原文链接A quick complete tutorial to save and restore Tensorflow models–by ANKIT SACHAN

猜你喜欢

转载自blog.csdn.net/sinat_34474705/article/details/78995196