Tensorflow命名空间及变量详细解析(三)

版权声明:找不到大腿的时候,让自己变成大腿. https://blog.csdn.net/Xin_101/article/details/88581250

1 Tensorflow命名空间及变量

  • variable_scope,
    新建变量时定义节点(操作)的上下文管理器.解析:新建变量时的管理器,因此对Variableget_variable均有效,因为Variable是新建变量,get_variable是新建或使用定义的变量,都是"新建"的过程.
  • name_scope
    定义Python节点(操作)时的上下文管理器,对get_variable无效,因为get_variable可能没有重新定义节点.
  • get_variable
    功能:获取已存在的变量或新建一个变量.
  • Variable
    功能:通过初始值新建变量.

2 variable_scope()

  • Demo
import tensorflow as tf
tf.reset_default_graph()
with tf.variable_scope("foo"):
    v1 = tf.Variable([250], name="v_1")
    v2 = tf.get_variable("v_2", initializer=[250.0], dtype=tf.float32)
    print("v1 name: {}".format(v1.name))
    print("v2 name: {}".format(v2.name))
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        print("v2 value: {}".format(v2.eval()))
with tf.variable_scope("foo_2"):
    v1 = tf.Variable([250], name="v_1")
    v2 = tf.get_variable("v_2", shape=[1])
    print("v1 name: {}".format(v1.name))
    print("v2 name: {}".format(v2.name))
  • Result
v1 name: foo/v_1:0
v2 name: foo/v_2:0
v2 value: [250.]
v1 name: foo_2/v_1:0
v2 name: foo_2/v_2:0
  • Analysis
    (1) name_scope对Variable变量提供外层保护,即把变量放在总的命名空间中,变量有一个空间属性,变量名称foo/v_1:0.
    (2) get_variable受name_scope约束,变量名为独立的foo/v_2:0.

3 name_scope()

import tensorflow as tf
tf.reset_default_graph()
v1 = tf.Variable([250], name="v_1")
print("v1 name: {}".format(v1.name))
with tf.name_scope("foo"):
    v1 = tf.Variable([250], name="v_1")
    v2 = tf.get_variable("v_2", [1])
    v3 = tf.Variable([250])
    print("v1 name: {}".format(v1.name))
    print("v2 name: {}".format(v2.name))
    print("v3 name: {}".format(v3.name))
  • Result
v1 name: v_1:0
v1 name: foo/v_1:0
v2 name: v_2:0
v3 name: foo/Variable:0
  • Analysis
    (1) name_scope对Variable变量提供外层保护,即把变量放在总的命名空间中,变量有一个空间属性,变量名称foo/v_1:0,不在命名空间的变量名称v_1:0.
    (2) get_variable不受name_scope约束,变量名为独立的v_2:0.

4 变量共享

  • Demo
import tensorflow as tf
tf.reset_default_graph()
with tf.variable_scope("foo"):
    v1 = tf.Variable([250], name="v_1")
    v2 = tf.get_variable("v_2", initializer=[250.0], dtype=tf.float32)
with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v3 = tf.Variable([250], name="v_1")
    v4 = tf.get_variable("v_2", dtype=tf.float32)
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    g1 = tf.get_default_graph()
    v2_name = v2.name
    v2_tensor_1 = g1.get_tensor_by_name("foo/v_2:0")
    v2_value = sess.run(v2_tensor_1)
    print("v2 name: {}".format(v2_name))
    print("v2 tensor: {}".format(v2_tensor_1))
    print("v2 value in foo: {}".format(v2_value))
    print("-----------")
    v4_name = v4.name
    v2_value = sess.run(v4)
    print("v4 name: {}".format(v4_name))
    print("v4 value in foo: {}".format(v2_value))
  • Result
v2 name: foo/v_2:0
v2 tensor: Tensor("foo/v_2:0", shape=(1,), dtype=float32_ref)
v2 value in foo: [250.]
-----------
v4 name: foo/v_2:0
v4 value in foo: [250.]
  • Analysis
    (1) 变量命名空间可以复用,即重复使用,如上foo命名空间使用了两次,变量可相同也可不相同,若相同则为变量共享,上例的张量foo/v_2:0是共享的.
    (2) 张量名相同,可通过不同的外部变量名获取张量值,如变量v4获取张量foo/v_2:0的值.

5 命名空间变量应用

  • Demo1
import tensorflow as tf
tf.reset_default_graph()

with tf.variable_scope("conv_1"):
    w1 = tf.get_variable("w_1", initializer=[250.0], dtype=tf.float32)
    b1 = tf.get_variable("b_1", initializer=[250.0], dtype=tf.float32)
    print("w1 name: {}".format(w1.name))
    
with tf.variable_scope("conv_2"):
    w2 = tf.get_variable("w_2", initializer=[250.0], dtype=tf.float32)
    b2 = tf.get_variable("b_2", initializer=[250.0], dtype=tf.float32)
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    a = w1.name
    print("a name: {}".format(a))
  • Result 1
w1 name: conv_1/w_1:0
a name: conv_1/w_1:0
  • Demo2
import tensorflow as tf
tf.reset_default_graph()

with tf.name_scope("conv_1"):
    w1 = tf.get_variable("w_1", initializer=[250.0], dtype=tf.float32)
    b1 = tf.get_variable("b_1", initializer=[250.0], dtype=tf.float32)
    print("w1 name: {}".format(w1.name))
    
with tf.name_scope("conv_2"):
    w2 = tf.get_variable("w_2", initializer=[250.0], dtype=tf.float32)
    b2 = tf.get_variable("b_2", initializer=[250.0], dtype=tf.float32)
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    a = w1.name
    g = tf.get_default_graph()
    w1_tensor = g.get_by_tensor_name(a)
    w1_value_by_run = sess.run(w1_tensor)
    w1_value_by_eval = w1_tensor.eval()
    print("a name: {}".format(a))
    print("w1 tensor: {}".format(w1_tensor))
    print("w1 value by run:{}".format(w1_value_by_run))
    print("w1 value by eval:{}".format(w1_value_by_eval))
  • Result 2
w1 name: w_1:0
a name: w_1:0
w1 tensor: Tensor("conv_1/w_1:0", shape=(1,), dtype=float32_ref)
w1 value by run: [250.0]
w1 value by eval: [250.0]
  • Analysis
    (1) 变量名同为w_1,在不同空间中定义,不冲突,因为变量空间对各自的变量其隔离保护作用;
    (2) 通过name属性可获取变量的完整名称如conv_1/w_1:0张量(Tensor)名称,可通过sess.runeval()获取张量值.

6 节点,张量及张量值获取

  • Demo
import tensorflow as tf
tf.reset_default_graph()

with tf.variable_scope("conv_1"):
    w1 = tf.get_variable("w_1", initializer=[250.0], dtype=tf.float32)
    b1 = tf.get_variable("b_1", initializer=[250.0], dtype=tf.float32)
    
with tf.variable_scope("conv_2"):
    w2 = tf.get_variable("w_2", initializer=[250.0], dtype=tf.float32)
    b2 = tf.get_variable("b_2", initializer=[250.0], dtype=tf.float32)
with tf.variable_scope("conv_3"):
    w3 = tf.Variable([250.0], name="w_3")
    b3 = tf.Variable([250.0], name="b_3")
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # 获取默认图结构
    g = tf.get_default_graph()
    print("------------")
    # 直接输出张量
    w3_direct_output = w3
    # 输出张量名称
    w3_name = w3.name
    # 通过节点名称获取变量节点即运算操作
    w3_operation = g.get_operation_by_name("conv_3/w_3")
    # 通过节点名称获取节点输出list
    w3_outputs = g.get_operation_by_name("conv_3/w_3").outputs
    # 通过节点名称获取节点输出list内容
    w3_outputs_value = g.get_operation_by_name("conv_3/w_3").outputs[0]
    # 通过张量名获取张量
    w3_tensor = g.get_tensor_by_name("conv_3/w_3:0")
    # 通过节点输出list内容获取变量值
    w3_value_by_operation = sess.run(w3_outputs_value)
    # 通过张量获取变量值
    w3_value_by_tensor = sess.run(w3_tensor)
    print("w3 direct outputs: {}".format(w3_direct_output))
    print("w3 name: {}".format(w3_name))
    print("w3 operation: {}".format(w3_operation))
    print("w3 outputs: {}".format(w3_outputs))
    print("w3 outputs value: {}".format(w3_outputs_value))
    print("w3 tensor: {}".format(w3_tensor))
    print("w3 value by operation: {}".format(w3_value_by_operation))
    print("w3 value by tensor: {}".format(w3_value_by_tensor))
  • Result
------------
w3 direct outputs: <tf.Variable 'conv_3/w_3:0' shape=(1,) dtype=float32_ref>
w3 name: conv_3/w_3:0
w3 operation: name: "conv_3/w_3"
op: "VariableV2"
attr {
  key: "container"
  value {
    s: ""
  }
}
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: 1
      }
    }
  }
}
attr {
  key: "shared_name"
  value {
    s: ""
  }
}

w3 outputs: [<tf.Tensor 'conv_3/w_3:0' shape=(1,) dtype=float32_ref>]
w3 outputs value: Tensor("conv_3/w_3:0", shape=(1,), dtype=float32_ref)
w3 tensor: Tensor("conv_3/w_3:0", shape=(1,), dtype=float32_ref)
w3 value by operation: [250.]
w3 value by tensor: [250.]
  • Analysis
    (1) Tensorflow定义的变量,直接输出的为张量Tensor,张量是存储数据的变量,名称为name:number,即变量名称结合Tensorflow的编号,如conv_3/w_3:0.
    (2) 节点的名称为name即直接定义的name就是节点名称,如conv_3/w_3.
    (3) 结合以上两点,get_tensor_by_name("conv_3/w_3:0")通过张量名称获取张量,get_by_operation_name("conv_3/w_3")获取节点内容.
    (4) 节点的输出内容有name,op,attr,比较多,但是可使用outputs属性获取节点的输出列表,列表的内容就是张量.
    (5) 获取变量的内容有两种方式:通过张量get_by_tensor_name,通过节点输出的内容:get_by_operation_name().outputs[0].
    (6) 通过张量或节点名称获取内容时需要先获取图结构,即定义获取图结构:tf.get_default_graph().

[参考文献]
[1]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/get_variable
[2]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/Variable
[3]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/name_scope
[4]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/variable_scope


猜你喜欢

转载自blog.csdn.net/Xin_101/article/details/88581250