Tensorflow模型的保存和加载

刚接触深度学习,Tensorflow模型的保存和加载尚不清楚,根据教程的翻译做一记录,不当之处敬请指正。

原文地址:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

在本教程中,将作出如下讲解:

    I.   Tensorflow模型是什么样的?

    II.  如何保存一个训练好的Tensoflow模型?

    III. 如何加载一个先前保存的Tensorflow模型?

    IV. 如何对pretrained模型进行fine-tuning和修改?

本教程假设你已经对神经网络的训练有了一定的基础。

1.Tensorflow的模型到底是什么样的?:

在训练好一个神经网络模型之后,我们通常希望将它保存下来,方便以后的使用。那么,什么是Tensorflow模型呢?Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等。所以,Tensorflow模型有两个主要的文件:

a) Meta graph:

这是一个协议缓冲区(protocol buffer),它完整地保存了Tensorflow图;即所有的变量、操作、集合等。此文件以 .meta 为拓展名。

b) Checkpoint 文件:

这是一个二进制文件,包含weights、biases、gradients 和其他所有变量的值。此文件以 .ckpt 为扩展名. 但是,从Tensorflow 0.11版本之后做出了一些改变。现在,不再是单一的 .ckpt 文件,而是一下两个文件:


.data文件包含了我们的训练变量,稍后再说。

另外,Tensorflow还有一个名为 checkpoint 的文件,仅用于保存最新checkpoint文件保存的记录。

总之,版本0.10以上的Tensorflow模型如下所示:


注意,Tensorflow 0.11版本之前的模型只有三个文件:


现在,我们知道了Tensorflow模型是什么样的,下面我们学习如何保存一个模型。

2. 保存一个Tensorflow模型:

比方说你正在训练一个卷积神经网络用于图像分类,你会关注于loss值和accuracy. 一旦你看到网络converged, 你就可以手工停止训练或设置固定的训练迭代次数。训练完成之后,我们想把所有变量值和网络图保存到文件中方便以后使用。所以,为了保存Tensorflow中的图和所有参数的值,我们创建一个tf.train.Saver()类的实例。

saver = tf.train.Saver()

别忘了Tensorflow变量仅存在于session内,所以你必须在session内进行保存,可通过调用创建的saver对象的sava方法实现。


其中,sess是session对象,‘my-test-model’是你对自己模型的命名。让我们看一个完整的例子:


如果我们想在迭代1000次后保存模型,我们需在对应的步数之后调用 sava 方法:

saver.save(sess, 'my_test_model',global_step=1000)

这只是在模型名字后面附加上‘-1000’,同时将创建下列文件:


比方说,在训练过程中,我们想每迭代1000次就保存模型一次。在第一次保存模型时会创建 .meta 文件(第1000次迭代),并且我们无需每次都重复创建(也就是无需在第2000、3000...或其他迭代时保存.meta文件)。因为图结构不变,我们仅需保存迭代后的参数值。因此,当我们不想写入meat-graph时,使用:


如果你想保存最近的4个模型并且每训练两个小时保存一次,可以使用 max_to_keep 和 keep_checkpoint_every_n_hours,如下所示:


注意,如果我们没有在tf.train.Saver()中指定任何参数,它会保存所有变量。如果我们不想保存全部变量而只是想保存一部分的话,我们可以指定想保存的variables/collections.在创建tf.train.Saver实例时,我们将它传递给我们想要保存的变量的列表或字典。看一个例子:


这可以用于保存Tensorflow图的特定部分。

3.引入一个pretrained模型:

如果你想引入其他的预先训练的模型来fine-tuning,需要做两件事:

a) 创建网络:

你可以通过写python代码手工创建原来模型的网络。或者,通过我们之前创建的 .mate 文件进行网络创建,使用tf.train.import()函数实现,如:saver = tf.train.import_meta_graph('my_test_model-1000.meta')

记住,import_meta_graph会将.meta文件中保存的网络加载到当前网络中,这会创建一个graph/network,但我们仍需加载已训练的各参数值。

b) 加载参数:

我们可以通过调用tf.train.Saver()类的restore方法来加载参数。


这样,像w1、w2这些tensors的值就加载进来了并且可以进行访问:


现在,你已经知道如何保存和加载Tensorflow模型。下面给出了一个使用练习加载预先训练的模型。。。。


猜你喜欢

转载自blog.csdn.net/Albert201605/article/details/79994331