tf.get_variable()和tf.Variable()的区别

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/qq_43546676/article/details/102689440

1.遇到关于tf.get_variable()的问题

今天用tensorflow写一个模型,过程中遇到一个很坑的问题,只运行下面这行代码会报错:

W = tf.get_variable('W', (3, 1), initializer=tf.constant_initializer())

# 报错:
ValueError: Variable W already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:

当时懵了,我只写了一行代码,上面没有建立名为W的变量,这里共享变量W却报错,说已经存在了,蛇皮玩意???
查了很多资料也没解决,后来意识到,是这串代码运行了多次,W变量已经在内存中了,相当于写了两行W = tf.get_variable('W', (3, 1), initializer=tf.constant_initializer())代码,我们知道,同名的共享变量不能创建两次,因为他不像Variable()能够自动处理命名重复的问题。

2. tf.get_variable()和tf.Variable()的区别

先看一段代码:

import tensorflow as tf
w_1 = tf.Variable(3,name="w_1")
w_2 = tf.Variable(1,name="w_1")
print(w_1.name)
print(w_2.name)
#输出
#w_1:0
#w_1_1:0
import tensorflow as tf
w_1 = tf.get_variable(name="w_1",initializer=1)
w_2 = tf.get_variable(name="w_1",initializer=2)
#错误信息
#ValueError: Variable w_1 already exists, disallowed. Did
#you mean to set reuse=True in VarScope?
import tensorflow as tf

with tf.variable_scope("scope1"):
    w1 = tf.get_variable("w1", shape=[])
    w2 = tf.Variable(0.0, name="w2")
with tf.variable_scope("scope1", reuse=True):
    w1_p = tf.get_variable("w1", shape=[])
    w2_p = tf.Variable(1.0, name="w2")

print(w1 is w1_p, w2 is w2_p)
#输出
#True  False

区别:

  • 使用tf.Variable时,如果检测到命名冲突,系统会自己处理。使用tf.get_variable()时,系统不会处理冲突,而会报错。
  • tf.Variable()每次都在创建新的对象,与name没有关系。而tf.get_variable()对于已经创建的同样name的变量对象,就直接把那个变量对象返回(类似于:共享变量),tf.get_variable() 会检查当前命名空间下是否存在同样name的变量,可以方便共享变量。
  • tf.get_variable():对于在上下文管理器中已经生成一个v的变量,若想通过tf.get_variable函数获取其变量,则可以通过reuse参数的设定为True来获取。
  • 还有一点,tf.get_variable()必须写name,否则报错

https://blog.csdn.net/qq_33915826/article/details/79793171

猜你喜欢

转载自blog.csdn.net/qq_43546676/article/details/102689440
今日推荐