【tensorflow】 tf.Variable, tf.get_variable之间的区别

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/shwan_ma/article/details/80488860

之前一直很困惑tf.Variable和tf.get_variable之间的区别,这几天算稍微明白一些。用简单的语言描述概括一下tf.Variable 和 tf.get_variable的一些特性

tf.Variable 和 tf.get_variable之间最主要的区别:
如果tf.Variable定义的时候,两个变量即使重名,那么是依然是两个独立的变量, tensorflow会自动增加变量后缀,以区分同名变量。
而 tf.get_variable定义的时候,tensorflow会自动去检查有没有命名一样的变量,如果出现一样,则会报错。但是如果设置了re-use,则不会报错,同时可以使得参数共享

首先,对于name_scope来说

如果在tf.name_scope下的话定义变量的话,
tf.Variable会在他的名字前面加上该name_scope的名字。
而 tf.get_varoable 则会无视掉name_scope

with tf.name_scope("a_name_scope"):
    var1 = tf.get_variable(name="var1", shape=[1], dtype=tf.float32, initializer=tf.constant_initializer(1.0))
    var2 = tf.Variable(name="var2", initial_value=[2], dtype=tf.float32)
    var3 = tf.Variable(name="var2", initial_value=[2.1], dtype=tf.float32)
    var4 = tf.Variable(name="var2", initial_value=[2.2], dtype=tf.float32)

sess.run(tf.global_variables_initializer())

print(var1.name)
print(sess.run(var1))
print(var2.name)
print(sess.run(var2))
print(var3.name)
print(sess.run(var3))
print(var4.name)
print(sess.run(var4))

输出结果:

var1:0
[ 1.]
a_name_scope/var2:0
[ 2.]
a_name_scope/var2_1:0
[ 2.0999999]
a_name_scope/var2_2:0
[ 2.20000005]


二, 对于variable_scope来说

则不管tf.Variable和tf.get_variable都会加上variable_scope的名字

with tf.variable_scope("a_variable_scope") as scope:
    initializer = tf.constant_initializer(value=3)
    var3 = tf.get_variable('var3', shape=[1], dtype=tf.float32, initializer=tf.constant_initializer(3.0))
    var4 = tf.Variable(name="var4", initial_value=[4], dtype=tf.float32)
    var4_reuse = tf.Variable(name="var4", initial_value=[4], dtype=tf.float32)

sess.run(tf.global_variables_initializer())
print(var3.name)
print(sess.run(var3))
print(var4.name)
print(sess.run(var4))
print(var4_reuse.name)
print(sess.run(var4_reuse))

输出结果:

a_variable_scope/var3:0
[ 3.]
a_variable_scope/var4:0
[ 4.]
a_variable_scope/var4_1:0
[ 4.]

三, 最重要的特性, tf_get_variable支持变量重用

with tf.variable_scope("a_variable_scope") as scope:
    initializer = tf.constant_initializer(value=3)
    var3 = tf.get_variable('var3', shape=[1], dtype=tf.float32, initializer=tf.constant_initializer(3.0))
    scope.reuse_variables()
    var3_reuse = tf.Variable(name="var3")

sess.run(tf.global_variables_initializer())
print(var3.name)
print(sess.run(var3))
print(var3_reuse.name)
print(sess.run(var3_reuse))

输出结果:(重用完成)

a_variable_scope/var3:0
[ 3.]
a_variable_scope/var3:0
[ 3.]

猜你喜欢

转载自blog.csdn.net/shwan_ma/article/details/80488860
今日推荐