TensorFlow --- 图的基本操作

1.建立图
关于建立图的几个基本操作

tf.Graph() # 建立图
tf.get_default_graph() # 获取图
tf.reset_default_graph() # 重置图

import numpy as np
import tensorflow as tf

# 在默认图里建立的
c = tf.constant(0.0)

# 建立了一个图,并且在新建的图里添加变量,
# 可以通过变量的'.graph'获取所在的图
g = tf.Graph()
# 表示使用tg.Graph函数来创建一个图
with g.as_default():
    c1 = tf.constant(0.1)
    print(c1.graph)
    print(g)
    print(c.graph)

# 获取默认图,所以跟c的值一样
g2 = tf.get_default_graph()
print(g2)

# 重建了一张图代替原来的默认图
tf.reset_default_graph()
g3 = tf.get_default_graph()
print(g3)
结果为:
<tensorflow.python.framework.ops.Graph object at 0x7f9c90074f28>
<tensorflow.python.framework.ops.Graph object at 0x7f9c90074f28>
<tensorflow.python.framework.ops.Graph object at 0x7f9caa041b00>
<tensorflow.python.framework.ops.Graph object at 0x7f9caa041b00>
<tensorflow.python.framework.ops.Graph object at 0x7f9c90074e48>

使用tf.reset_default_graph函数时,必须保证当前图的资源已经全部释放,否则会报错。

2.获取张量
通过get_tensor_by_name可以获得图里面的张量

接上述例子
...
print(c1.name)
t = g.get_tensor_by_name(name='Const:0')
print(t)
结果为:
Const:0
Tensor("Const:0", shape=(), dtype=float32)

3.获取节点的操作
通过get_operation_by_name来获取节点

...
a = tf.constant([1.0, 2.0])
b = tf.constant([1.0], [3.0])

tensor1 = tf.matmul(a, b, name='exampleop')
print(tensor1.name, tensor1)
test = g3.get_tensor_by_name('exampleop:0')
print(test)

print(tensor1.op.name)
testop = g3.get_operation_by_name('exampleop')
print(testop)

with tf.Session() as sess:
    test = sess.run(test)
    print(test)
    test = tf.get_default_graph().get_tensor_by_name('exampleop:0')
    print(test)
结果为:
exampleop:0 Tensor("exampleop:0", shape=(1, 1), dtype=float32)
Tensor("exampleop:0", shape=(1, 1), dtype=float32)

# tensor1.op.name
exampleop

# get_operaion_by_name
name: "exampleop"
op: "MatMul"
input: "Const"
input: "Const_1"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "transpose_a"
  value {
    b: false
  }
}
attr {
  key: "transpose_b"
  value {
    b: false
  }
}

[[7.]]
Tensor("exampleop:0", shape=(1, 1), dtype=float32)

4.获取元素列表
通过get_operaions获取图中的所有元素

...
tt2 = g.get_operations()
print(tt2)
结果为:
[<tf.Operation 'Const' type=Const>]

5.获取对象
通过tf.Graph.as_graph_element(obj, allow_tensor=True, allow_operation=True)函数根据对象获取元素。即传入的是一个对象,返回的是一个张量或一个op。该函数具有验证和转换功能,在多线程方面偶尔会用到。

...
tt3 = g.as_graph_element(c1)
print(tt3)
结果为:
Tensor("Const:0", shape=(), dtype=float32)

猜你喜欢

转载自blog.csdn.net/jian15093532273/article/details/80776519