2018/4/26(二)

一个快速完整的教程来保存和恢复Tensorflow    

http://cv-tricks.com/tensorflow-tutorial/-saverestore-tensorflow-models-quick-complete-tutorial/

在这个Tensorflow教程中,我将解释:

  1. Tensorflow模型是怎样的?
  2. 如何保存Tensorflow模型?
  3. 如何恢复预测/传输学习的Tensorflow模型
  4. 如何使用导入的预训练模型进行微调和修改

本教程假定您对训练神经网络有一些想法。否则,请按照本教程进行操作并返回此处。

1.什么是Tensorflow模型?

在训练完神经网络之后,您需要将其保存以供将来使用并部署到生产环境。那么,什么是Tensorflow模型?Tensorflow模型主要包含我们已经培训的网络设计或者图形和网络参数的值。因此,Tensorflow模型有两个主要文件:

a)元图:

这是一个保存完整Tensorflow图的协议缓冲区; 即所有变量,操作,集合等。该文件具有.meta扩展名。

b)检查点文件:

这是一个二进制文件,其中包含权重,偏差,梯度和所有其他变量的所有值。这个文件有一个扩展名。CKPT。但是,Tensorflow已经从版本0.11改变了这一点。现在,而不是单个.ckpt文件,我们有两个文件:

.data文件是包含我们的训练变量的文件,我们将继续。

除此之外,Tensorflow还有一个名为checkpoint的文件,它只保存最新检查点文件的记录。

因此,总而言之,版本大于0.10的Tensorflow模型如下所示:

Tensorflow教程

而0.11之前的Tensorflow模型仅包含三个文件:

现在我们知道了Tensorflow模型的外观,我们来学习如何保存模型。

2.保存Tensorflow模型:

假设您正在训练用于图像分类的卷积神经网络作为一种标准做法,您需要关注损失和准确性数字。一旦您看到网络已经融合,您可以手动停止训练,或者您将运行固定数量的时期训练。培训完成后,我们希望将所有变量和网络图保存到一个文件以供将来使用。因此,在Tensorflow中,您想要保存要为其创建tf.train.Saver()类实例的所有参数的图形和值。

saver = tf.train.Saver()

请记住,Tensorflow变量只在会话中存在。因此,您必须通过调用刚创建的保存程序对象上的save方法将模型保存在会话中。

这里,sess是会话对象,而'my-test-model'是你想要给你的模型的名字。我们来看一个完整的例子:

如果我们在1000次迭代后保存模型,我们将通过传递步数来调用save:

saver.save(sess, 'my_test_model',global_step=1000)

这只会将'-1000'附加到型号名称上,并创建以下文件:

比方说,在训练时,我们在每1000次迭代后保存模型,所以.meta文件是第一次创建(在第1000次迭代中),我们不需要每次都重新创建.meta文件(所以,将.meta文件保存为2000,3000 ..或任何其他迭代)。我们只保存模型以进一步迭代,因为图形不会改变。因此,当我们不想编写元图时,我们使用这个:

如果您只想保留4个最新型号,并且想要在训练期间每2小时保存一个型号,则可以使用max_to_keep和keep_checkpoint_every_n_hours。

 

请注意,如果我们没有在tf.train.Saver()中指定任何内容,它会保存所有变量如果我们不想保存所有的变量而只保存其中的一部分,会怎样呢?我们可以指定我们想要保存的变量/集合。在创建tf.train.Saver实例时,我们将它传递给我们想要保存的变量的列表或字典。我们来看一个例子:

这可用于在需要时保存Tensorflow图的特定部分。

3.导入预先训练的模型:

如果你想使用别人的预先训练好的模型进行微调,你需要做两件事情:

a)创建网络:

您可以通过编写Python代码来创建网络,以手动创建每个图层作为原始模型。然而,如果你仔细想想,我们已经将网络保存在.meta文件中,我们可以使用tf.train.import()函数来重新创建网络,如下所示:saver = tf.train.import_meta_graph('my_test_model-1000.meta')

请记住,import_meta_graph会将.meta文件中定义的网络附加到当前图形中。因此,这将为您创建图形/网络,但我们仍然需要加载我们在此图上训练过的参数的值。

b)加载参数:

我们可以通过调用该保存程序中的恢复来恢复网络的参数,该程序是tf.train.Saver()类的一个实例。

在此之后,像w1和w2这样的张量值已经恢复并可以被访问:

所以,现在您已经了解了Tensorflow模型的保存和导入工作原理。在下一节中,我已经描述了上述的实际用法来加载任何预先训练好的模型。

4.使用恢复的模型

既然您已经了解了如何保存和恢复Tensorflow模型,那么让我们开发一个实用指南,以恢复任何预先训练好的模型,并将其用于预测,微调或进一步培训。无论何时使用Tensorflow,您都可以定义一个图表,其中包含示例(训练数据)和一些超参数,例如学习速率,全局步长等。使用占位符提供所有训练数据和超参数是一种标准做法。我们使用占位符构建一个小型网络并保存它。请注意,保存网络时,不会保存占位符的值。

现在,当我们想要恢复它时,我们不仅需要恢复图形和权重,还要准备一个新的feed_dict,将新的训练数据馈送到网络。我们可以通过graph.get_tensor_by_name()方法获得对这些保存的操作和占位符变量的引用

如果我们只想用不同的数据运行同一个网络,只需将新数据通过feed_dict传递给网络即可。

如果您想通过添加更多图层并添加更多图层来为图表添加更多操作。当然你也可以这样做。看这里:

但是,您是否可以恢复部分旧图形和插件以进行微调?当然,您可以通过graph.get_tensor_by_name()方法访问相应的操作,并在其上创建图形。这是一个真实世界的例子。在这里,我们使用元图加载一个vgg预训练网络,并将最后一层中的输出数量更改为2,以便用新数据进行微调。



猜你喜欢

转载自blog.csdn.net/wangweiijia/article/details/80099910