Tensorflow入门教程三—计算图

计算图是Tensorflow一个基本概念,Tensorflow中的所有计算都会被转化为计算图上的节点。Tensorflow中的每一个计算都是计算图上的一个节点,而节点之间的边描述了计算之间的依赖关系。如下图所示
这里写图片描述
MatMul矩阵相乘运算依赖张量w,x。

Tensorflow的程序可以分为两个阶段,第一阶段需要定义计算图中所有的计算,第二节阶段为执行计算。定义计算的样例如下

import tensorflow as tf
a = tf.constant([[1,2,2],[1,2,3]])
b = tf.constant([[1,2,2],[1,2,3]])
result = a + b

通过tf.get_default_graph()函数可以获取当前默认的计算图。以下代码可以查看当前运算是否属于计算图

#a.graph可以查看计算a的图属性。
print(a.graph is tf.get_default_graph())

Tensorflow支持通过tf.Graph()函数来生成新的计算图。不同计算图上的张量和运算不会共享。下面代码示意了如何在不同计算图中定义和使用变量。

import tensorflow as tf
g1 = tf.Graph() #在图1中定义扁郎v,初始化为0
with g1.as_default():
    v = tf.Variable(1,name='v')

g2 = tf.Graph() #在图2中定义扁郎c,初始化为1
with g2.as_default():
    c = tf.Variable(0,name='v')
with tf.Session(graph = g1) as sess: #读取图1中的v
    tf.global_variables_initializer().run()
    print(sess.run(v))

with tf.Session(graph = g2) as sess:#读取图2中的c
    tf.global_variables_initializer().run()
print(sess.run(c))

Tensorflow中的计算图不仅仅不可以用来隔离张量和计算,它还提供了管理张量和计算的机制。计算图可以通过tf.Graph().device()函数来指定运行计算的设备。以下程序将加法跑在GPU上。

import tensorflow as tf
g1 = tf.Graph()
with g1.as_default():
    a = tf.constant([1,2])
    b = tf.constant([2,3])
    with g1.device('/gpu:0'):
        result = a + b
with tf.Session(graph=g1) as sess:
    print(sess.run(result))

有效利用Tensorflow程序中的资源也是计算图的一个重要功能。在计算图中,可以通过集合(collection)来管理不同类别的资源,通过tf.add_to_collection()函数可以将资源加入一个或多个集合中,然后通过tf.get_collection()获取一个集合里面的所有资源,这里的资源可以是张量,变量或者运行Tensorflow程序所需要的队列资源。Tensorflow也自动管理一些最常用的集合。

集合名称 集合内容 使用场景
tf.GraphKeys.VARIABLES 所有变量 持久化Tensorflow模型
tf.GraphKeys.TRAINABLE_VARIABLES 可学习的变量(一般指神经网络中的参数) 模型训练,生成模型可视化内容
tf.GraphKeys.SUMMARIES 日志生成相关的张量 Tensorflow计算可视化
tf.GraphKeys.QUEUE_RUNNERS 处理输入的QueueRunner 输入处理
tf.GraphKeys.MOVING_AVERAGE_VARIABLES 所有计算了滑动平均值的变量 计算变量的滑动平均值

猜你喜欢

转载自blog.csdn.net/weixin_37895339/article/details/78936375