Tensorflow-GraphDef、MetaGraph、CheckPoint

Tensorflow框架实现的三种图

参考原文:http://www.360doc.com/content/17/1123/18/7669533_706522939.shtml

==========================

Graph:

Tensorflow所运行的代码,或者说用python代码表达的计算,所描述的对象实际上就是一张计算图,包含了各个运算节点和用于计算的张量。而Graph_def是图Graph的序列表示。python所描述的这个graph,并不是在运行Tensorflow,启动一个Session后就保持不变的,因为Tensorflow在实际运行过程中,真实的计算是会被下放到多CPU,或者GPU、ARM等异构设备上进行高性能计算的,如果仅仅单纯地使用python肯定是无法有效地完成计算的。所以Tensorflow的实际计算过程是这样的:

Tensorflow先将python代码所描绘的图进行转换,转化成Protocol Buffer(即序列化),再通过C/C++/CUDA运行Protocol Buffer所定义的图。

(Protocol Buffer:

https://www.ibm.com/developerworks/cn/linux/l-cn-gpb/

Tensorflow实战Google深度学习框架Chapter 2 )

==========================

GraphDef:

从python代码描述的Graph中序列化得到的图就叫做GraphDef。GraphDef可以理解为一种数据结构。GraphDef是由许多叫做NodeDef的Protocol Buffer组成。其中NodeDef也可以理解为是数据结构。(实际上从数据结构的角度上就很好理解这些内容)。GraphDef强调的是操作节点之间的联系。Tensorflow中通过NodeDef中的input这一attribute来定义Node之间的连接信息。

在概念上,NodeDef与python代码描绘的Graph中的操作运算节点Operation相对应。可知GraphDef中只有NodeDef,也就是说只有python描述的Graph中的Operation,并没有Variable。所以这也反映出了GraphDef这个图强调的是python描述的Graph的连接信息,并不保存Variable的相关信息(注意并不是所有Tensor的相关信息都不保存,constant类型的Tensor的相关信息就会在GraphDef中保存)。所以如果要从graph_def来构建图并恢复训练的话,是不一定能成功的,因为缺少了例如Variable等这些Tensor。

在实际线上Inference中,通常使用的是GraphDef。虽然GraphDef不会保存Variable这类Tensor,但是会保存constant这类Tensor,所以还是可以用来存储例如weights这些参数的。在Tensorflow 1.3.0版本中提供了一套叫做freeze_graph的工具来自动地将python所描述的Graph中的Variable替换成constant存储在GraphDef中,并将该Graph导出为Proto.

(freeze_graph:

https://www.tensorflow.org/extend/tool_developers/

tf.train.writer_graph()/tf.import_graph_def()就是用来进行GraphDef读写的API。

可知如果仅仅从GraphDef中是无法得到Variable的。

==========================

MetaGraph:

在GraphDef中无法得到Variable,而通过MetaGraph可以得到。

MetaGraph的官方解释:一个MetaGraph是由一个计算图和其相关的元数据构成的。其包含了用于继续训练、实施评估和(在已经训练好的Graph图上)做前向推断的信息。

https://www.tensorflow.org/versions/r1.1/programmers_guide/

MetaGraph在具体实现上,就是一个MetaGraphDef(同样是由Protocol Buffer来定义的)。其中包含了四种主要的信息:

MetaInfoDef: 存放了一些元信息,例如版本和其他用户信息;

GraphDef: MetaGraph的核心内容之一;

SaverDef: 图的Saver信息,例如最多同时保存的checkpoint数量,需要保存的Tensor名字等,但并不保存Tensor中的实际内容;

CollectionDef: 任何需要特殊注意的python对象,需要特殊的标注以方便import_meta_graph后取回,例如”train_op”,”prediction”等等。

其中着重介绍CollectionDef,其为Collection对应的Protocol Buffer。

集合collection是为了方便用户对图中的操作和变量进行管理而被创建的一个概念,通过一个string类型的key来对一组python对象进行命名的集合。这个key可以是Tensorflow在内部定义的一些key,也可以是用户自定义的名字,但是注意是string类型。它有一点命名空间的意思,将变量收录进某一个集合collection中。

Tensorflow内部定义了许多标准的key,全部定义在了tf.GraohKeys这个类当中。其中有一些是常用的,tf.GraphKeys.TRAINABLE_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES等等。tf.trainable_variables()和tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)是等价的;tf.global_variables()和tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)是等价的。

对于用户定义的key:

pred = model_network(X)

loss = tf.readuce_mean(…, pred, …)

train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

对于这一段对训练过程定义的代码,用户希望特别关注pred, loss, train_op这几个操作,那么就可以使用如下代码,将这几个变量加入集合collection中去。令这个集合名为”training_collection”:

tf.add_to_collection(‘training_collection’, pred)

tf.add_to_collection(‘training_collection’, loss)

tf.add_to_collection(‘training_collection’, train_op)

并且可以通过Train_collect = tf.get_collection(‘training_collection’)得到一个python的list,list中的元素就是加入集合的几个变量pred, loss, train_op。这通常是为了在一个新的Session中打开这张Graph时,方便我们获取想要的操作节点Operation。例如可以通过tf.get_collection()得到train_op,然后通过sess.run(train_op)来进行训练,而无需重新构建loss和Optimizer。

通过tf.export_meta_graph()保存Graph,得到MetaGraph,并通过add_to_collection()将操作Operation加入collection中:

with tf.Session() as sess:

pred = model_network(X)

loss = tf.readuce_mean(…, pred, …)

train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

tf.add_to_collection(‘training_collection’, train_op)

Meta_graph_def = tf.train.export_meta_graph(tf.get_default_graph(), ‘my_graph.meta’)

通过import_meta_graph将MetaGraph恢复,同时初始化为本Session的default Graph,并通过get_collection重新获得train_op,以及通过train_op开始一段训练(sess.run())。

从MetaGraph中恢复构建的图Graph是可以被训练的。

https://www.tensorflow.org/api_guides/python/meta_graph

需要特殊说明的是,MetaGraph中虽然包含Variable的信息,但是没有Variable的实际值。所以从MetaGraph中恢复的图Graph,训练都是从随机初始化的值开始的,训练中的Variable 的实际值都保存在checkpoint文件中,如果要从之前训练的状态继续恢复训练,就需要从checkpoint中restore。

tf.export_meta_graph()/tf.import_meta_graph()即为用来进行MetaGraph读写的API。tf.train.saver.save()在保存checkpoint的同时也会保存MetaGraph,但是在恢复图时,tf.train.saver.restore()只恢复Variable。如果要从MetaGraph中恢复图Graph,需要使用tf.import_meta_graph()。这其实是为了方便用户,因为有时我们不需要从MetaGraph中恢复图Graph,而仅仅需要在python中构建NN的Graph,并恢复对应的Variable。

==========================

CheckPoint:

CheckPoint中全面保存了训练某时间截面的信息,包括参数、超参数、梯度等等。tf.train.Saver()/tf.saver.restore()则能够完整地保存和恢复神经网络的训练。CheckPoint分为两个文件保存Variable的二进制信息:ckpt文件保存了Variable的二进制信息,index文件用于保存ckpt文件中对应Variable的偏移量信息。

==========================

总结:

Tensorflow三种API所保存和恢复的图Graph是不一样的。这三种图是从Tensorflow框架设计的角度出发定义的。简而言之,Tensorflow在前段python中构建图Graph,并且通过将该图序列化到Protocol Buffer得到GraphDef,以方便在后端运行。在这个过程中,图的保存、恢复、运行都通过ProtoBuf来实现。GraphDef、MetaGraph以及Variable、Collection、Saver等都有对应的ProtoBuf定义。ProtoBuf的定义也决定了用户能对图进行的操作。例如用户只能找到Node的前一个Node,却无法得知自己的输出会被哪个Node接受。

猜你喜欢

转载自blog.csdn.net/weixin_39721347/article/details/86171990