【转载】 TensorFlow - 框架实现中的三种 Graph图结构

原文地址:

https://zhuanlan.zhihu.com/p/31308381

-------------------------------------------------------------------------------------

图(graph)是 tensorflow 用于表达计算任务的一个核心概念。从前端(python)描述神经网络的结构,到后端在多机和分布式系统上部署,到底层 Device(CPU、GPU、TPU)上运行,都是基于图来完成。然而我在实际使用过程中遇到了三对API,

  1. tf.train.Saver()/saver.restore()
  2. export_meta_graph/Import_meta_graph
  3. tf.train.write_graph() / tf.Import_graph_def()

他们都是用于对图的保存和恢复。同一个计算框架,为什么需要三对不同的API呢?他们保存/恢复的图在使用时又有什么区别呢?初学的时候,常常闹不清楚他们的区别,以至常常写出了错误的程序,经过一番研究,在本文中对Tensorflow中围绕Graph的核心概念进行了总结。

Graph

首先介绍一下关于 Tensorflow 中 Graph 和它的序列化表示 Graph_def。在Tensorflow的官方文档中,Graph 被定义为“一些 Operation 和 Tensor 的集合”。例如我们表达如下的一个计算的 python代码,

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = tf.placeholder(tf.float32)
d = a*b+c
e = d*2

就会生成相应的一张图,在Tensorboard中看到的图大概如下这样。其中每一个圆圈表示一个Operation(输入处为Placeholder),椭圆到椭圆的边为Tensor,箭头的指向表示了这张图
Operation 输入输出 Tensor 的传递关系。

 这张图所表达的数据流 与 python 代码中所表达的计算是对应的关系(为了称呼方便,我们下面将这张由Python表达式所描述的数据流动关系叫做 Python Graph)。然而在真实的 Tensorflow 运行中,Python 构建的“图”并不是启动一个Session之后始终不变的东西。因为Tensorflow在运行时,真实的计算会被下放到多CPU上,或者 GPU 等异构设备,或者ARM等上进行高性能/能效的计算。单纯使用 Python 肯定是无法有效完成的。实际上,Tensorflow而是首先将 python 代码所描绘的图转换(即“序列化”)成 Protocol Buffer,再通过 C/C++/CUDA 运行 Protocol Buffer 所定义的图。(Protocol Buffer的介绍可以参考这篇文章学习:

GraphDef

从 python Graph中序列化出来的图就叫做 GraphDef(这是一种不严格的说法,先这样进行理解)。而 GraphDef 又是由许多叫做 NodeDef 的 Protocol Buffer 组成。在概念上 NodeDef 与 (Python Graph 中的)Operation 相对应。如下就是 GraphDef 的 ProtoBuf,由许多node组成的图表示。这是与上文 Python 图对应的 GraphDef:

 
 
node {
  name: "Placeholder"     # 注释:这是一个叫做 "Placeholder" 的node
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}

node { name: "Placeholder_1" # 注释:这是一个叫做 "Placeholder_1" 的node op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { unknown_rank: true } } } }

node { name: "mul" # 注释:一个 Mul(乘法)操作 op: "Mul" input: "Placeholder" # 使用上面的node(即Placeholder和Placeholder_1) input: "Placeholder_1" # 作为这个Node的输入 attr { key: "T" value { type: DT_FLOAT } } }
 

以上三个 NodeDef 定义了两个Placeholder和一个Multiply。Placeholder 通过 attr(attribute的缩写)来定义数据类型和 Tensor 的形状。Multiply通过 input 属性定义了两个placeholder作为其输入。无论是 Placeholder 还是 Multiply 都没有关于输出(output)的信息。其实 Tensorflow 中都是通过 Input 来定义 Node 之间的连接信息。

那么既然 tf.Operation 的序列化 ProtoBuf 是 NodeDef,那么 tf.Variable 呢?在这个 GraphDef 中只有网络的连接信息,却没有任何 Variables呀?没错,Graphdef
中不保存任何 Variable 的信息,所以如果我们从 graph_def 来构建图并恢复训练的话,是不能成功的。
比如以下代码,

with tf.Graph().as_default() as graph:
  tf.import_graph_def("graph_def_path")
  saver= tf.train.Saver()
  with tf.Session() as sess:
    tf.trainable_variables()

其中 tf.trainable_variables() 只会返回一个空的list。Tf.train.Saver() 也会报告 no variables to save。

然而,在实际线上 inference 中,通常就是使用 GraphDef然而,GraphDef中连Variable都没有,怎么存储weight呢?原来GraphDef 虽然不能保存 Variable,但可以保存 Constant 通过 tf.constantweight 直接存储在 NodeDef 里,tensorflow 1.3.0 版本也提供了一套叫做 freeze_graph 的工具来自动的将图中的 Variable 替换成 constant 存储在 GraphDef 里面,并将该图导出为 Proto。可以查看以下链接获取更多信息,

tensorflow.org/extend/t

tf.train.write_graph()        tf.Import_graph_def() 就是用来进行 GraphDef 读写的API。那么,我们怎么才能从序列化的图中,得到 Variables呢?这就要学习下一个重要概念,MetaGraph。

猜你喜欢

转载自www.cnblogs.com/devilmaycry812839668/p/12467580.html