TensorFlow技术内幕(十三):模型保存与恢复

模型训练过程中,我们希望在训练一段时间或一定次数之后,保存模型的当前状态,用于实验分析,或故障恢复,又或则是提供给线上服务使用。

Tensorflow中模型的状态包括两个方面,首先是模型的结构,其次是参数和参数的当前值。

本文介绍一下tensoflow中模型的这两类信息是如何保存和恢复的。

模型结构

模型的结构信息是静态信息,在构建完成后一般不会再变化,保存的过程比较简单直接,首先将结构信息收集并写到MetaGraphDef对象中,然后将MetaGraphDef对象序列化后存储到外部存储.

MetaGraphDef是protobuf定义的message, 囊括了恢复模型结构的所有必要信息,简单介绍下面几类信息:

名称 信息类别
MetaInfoDef 主要包括模型中使用的所有operation所对应的OpDef定义信息
GraphDef 模型Graph结构信息,也就是模型中Node的信息
SaverDef Saver的配置信息,包括文件名,是否分片,checkpoint保存周期等等
CollectionDef collection name到collection的映射关系信息(简单来说,collection定义了一组节点,或则变量的集合)

下面是结构信息保存过程的时序图:

在这里插入图片描述

模型参数

参数的个数一般是固定的,参数的值是随着训练的过程不断变化的。那么如果保存参数信息呢?

一般来说,一个比较朴素直接的方法就是遍历所有待保存的参数,获取参数的值并存储到外部存储中去。那么这个朴素版本的方案有没有什么问题呢?

其实,在设计方案的时候有一个原则,就是数据规模决定实现的方法。在数据量不大的情况下,这个保存参数值方法并没什么问题。但是当数据规模大了以后,这个方法就有很严重的效率问题,并且在超大规模的训练中,这个方法实际上是不可行的。

为什呢?我们知道,在tensoflow集群方式训练的过程中,集群会由若干个worker节点和若干个ps(parameter server)节点组成, worker节点承担主要的运算量,ps节点共同负责存储和更新所有的参数。

上述的参数保存方法中,保存过程的执行者是worker节点,它会将所有参数的参数值从ps节点拉取到本地,然后从本地写到外部存储中去,这样一来,如果参数的纬度比较大,参数的传输就会占用很大的带宽,造成性能问题,并且如果参数纬度更大,超过单机内存容量,那么这个方法就会耗尽worker的内存。

显然这个方法不是一个很好的方案,那么tensorflow是如何保存参数的呢?

tensorflow参数保存的功能,同样也实现在tf.train.Saver类里,通过两个步骤来实现:

  • 第一步:在模型Graph中添加Save Node,并将所有需要保存值的Node(一般来说也就是Variable Node)作为Save Node的输入节点。
  • 第二步:在需要保存参数值的时刻,运行Save Node,Save Node调用Save Op,将输入全部写到外部存储中。

这样的设计就避免了参数从ps到worker的传输,因为Save Node一般会分配到参数所在的ps上。

下面分析一下参数保存的具体过程,其中Save Node的添加发生在tf.train.Saver的构造函数内:

在这里插入图片描述

注意,上面的时序图中包括了Restore Node的添加过程,实际上也是这样的,Save和Restore都是在Saver的构造函数里添加的。

完成第一步之后,第二步就是在需要保存参数的时刻,调用Saver的Save方法,运行Save Node,时序图如下:

在这里插入图片描述

我们可能还比较关心的是,上面Saver构造函数里添加的Save Node和Restore Node的实现。我们以Saver Node为例介绍一下

SaveOp

Saver Node调用的是Save操作, Save操作在CPU上的操作核是SaveV2, 它继承自OpKnernel, 保存参数的功能在函数Compute中,保存的时序图如下:
在这里插入图片描述

SaveV2有三个固定的输入,分别是prefix, tensor_names和shape_and_sclices, 分别表示保存目录、tensor名、tensor形状和切片信息;剩下的若干个输入都是需要保存的tensor的值,数量等于tensor_names数组的长度。

tensor形状和切片信息存在的原因是,作为SaveV2输入的tensor,可能是原始tensor的一部分,也就是所谓的原始tensor的一个分片,SaveV2支持tensor分片单独存储。例如当我们声明Partion Variable的时候,就会出现tensor分片的情况。

总结

上面我们详细分析了模型的保存过程,模型恢复的过程就是存储过程的逆向过程,了解了存储了过程之后,掌握恢复过程就比较简单了,这里就不再展开介绍了。

发布了52 篇原创文章 · 获赞 105 · 访问量 7万+

猜你喜欢

转载自blog.csdn.net/gaofeipaopaotang/article/details/88777678
今日推荐