TensorFlow Shared Variables

This chapter explains how to use shared variables, a very commonly used feature of TensorFlow. For the other parts, see: TensorFlow study index

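Contents

1. A look at tf.Variable()

2. Replacing tf.Variable() with the combination of tf.variable_scope() and tf.get_variable()

3. How "with tf.variable_scope(...) as vs:" affects tf.variable_scope()

4. tf.name_scope: a namespace that applies only to ops, not to variables created with tf.get_variable()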


1. A look at tf.Variable()

First, recall how tf.Variable() is used:

import tensorflow as tf
var1 = tf.Variable(tf.constant(0.5), name='var', dtype=tf.float32)
print(var1.name)

Output:

var:0

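Now consider what happens if I want to keep using this variable in another model, for example in both the generator and the discriminator of a GAN. Calling tf.Variable() again with the same variable name creates a new variable instead of returning the one we wanted.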

import tensorflow as tf

var1 = tf.Variable(tf.constant(0.5), name='var', dtype=tf.float32)
print(var1.name)

var2 = tf.Variable(tf.constant(0.5), name='var', dtype=tf.float32)
print(var2.name)

Output:

var:0
var_1:0

As the output shows, the system simply assigned a new variable name instead of reusing the var1 we defined earlier.
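
To see why the second call ends up as var_1, it can help to picture the graph keeping a counter for every requested name. The following is only a toy sketch of that idea in plain Python: ToyGraph and unique_name are hypothetical names made up for illustration, and TensorFlow's real naming logic handles more cases than this.

```python
class ToyGraph:
    """Toy stand-in for a graph's name bookkeeping; not TensorFlow code."""

    def __init__(self):
        # base name -> number of times it has been requested so far
        self._used = {}

    def unique_name(self, name):
        count = self._used.get(name, 0)
        self._used[name] = count + 1
        # The first request keeps the name; later ones get a numeric suffix.
        return name if count == 0 else "%s_%d" % (name, count)


graph = ToyGraph()
first = graph.unique_name("var")
second = graph.unique_name("var")
print(first + ":0")   # var:0
print(second + ":0")  # var_1:0
```

In this model nothing is ever looked up by name, which is exactly why a second tf.Variable() call cannot hand back the first variable.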

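2. Replacing tf.Variable() with the combination of tf.variable_scope() and tf.get_variable()

  • tf.get_variable(): if the variable name has not been used yet, this call creates a new variable, much like tf.Variable(). If the name has already been used in the same graph, calling it directly raises an error, because by default tf.get_variable() may create a variable with a given name only once per graph (example: code_1). To share the existing variable instead, combine it with tf.variable_scope() and reuse=True (example: code_2).
  • tf.variable_scope() acts as a namespace. Scopes can be nested, so a scope can be thought of as an address or a path.

# code_1: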

import tensorflow as tf

var3 = tf.get_variable(name='var_', dtype=tf.float32, initializer=tf.constant(3.3333))
print(var3.name)

var4 = tf.get_variable(name='var_', dtype=tf.float32)
print(var4.name)

Output (an error is raised):

  File "D:/pycodeLIB/TensorFlow/test.py", line 22, in <module>
    var4 = tf.get_variable(name='var_', dtype=tf.float32)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1328, in get_variable
    constraint=constraint)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1090, in get_variable
    constraint=constraint)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 435, in get_variable
    constraint=constraint)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 404, in _true_getter
    use_resource=use_resource, constraint=constraint)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 743, in _get_single_variable
    name, "".join(traceback.format_list(tb))))
ValueError: Variable var_ already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

  File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\ops.py", line 1740, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\ops.py", line 3414, in create_op
    op_def=op_def)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)

The code below shows the correct usage. I deliberately nested the tf.variable_scope() calls in it.

# code_2

import tensorflow as tf

with tf.variable_scope('test1'):
    with tf.variable_scope('test2'):
        var3 = tf.get_variable('var1', initializer=tf.constant(value=[1, 2, 3, 4, 5], shape=[5], dtype=tf.float32), dtype=tf.float32)
        print(var3.name)

with tf.variable_scope('test1', reuse=True):
    with tf.variable_scope('test2'):
        var4 = tf.get_variable('var1', dtype=tf.float32)
        print(var4.name)

Output. Both names are identical, which shows that the same variable is being used:

test1/test2/var1:0
test1/test2/var1:0
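
The behaviour of code_1 and code_2 can be pictured as a get-or-create lookup keyed by the full scoped name. The sketch below is a simplified model in plain Python, not TensorFlow's implementation: ToyVariableStore and the AUTO_REUSE constant are hypothetical stand-ins, and I am assuming the three reuse modes behave as the error message above suggests (False only creates, True only looks up, AUTO_REUSE does whichever applies).

```python
AUTO_REUSE = "auto"  # hypothetical stand-in for tf.AUTO_REUSE


class ToyVariableStore:
    """Toy model of a per-graph variable store; not TensorFlow code."""

    def __init__(self):
        self._vars = {}  # full scoped name -> variable record

    def get_variable(self, full_name, reuse=False, initial_value=None):
        exists = full_name in self._vars
        if reuse is True:
            # Pure lookup: the variable must already be there.
            if not exists:
                raise ValueError("Variable %s does not exist" % full_name)
            return self._vars[full_name]
        if exists:
            if reuse == AUTO_REUSE:
                return self._vars[full_name]
            # reuse is False and the name is taken: same failure as code_1.
            raise ValueError("Variable %s already exists, disallowed" % full_name)
        record = {"name": full_name + ":0", "value": initial_value}
        self._vars[full_name] = record
        return record


store = ToyVariableStore()
var3 = store.get_variable("test1/test2/var1", initial_value=[1.0, 2.0, 3.0, 4.0, 5.0])
var4 = store.get_variable("test1/test2/var1", reuse=True)
print(var3 is var4)  # True: both names refer to one record

try:
    store.get_variable("test1/test2/var1")  # no reuse, name already taken
except ValueError as err:
    print(err)
```

The point of the model is that sharing is decided by the full path (scope names plus variable name) together with the reuse flag of the scope that is active at the call.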

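3. How "with tf.variable_scope(...) as vs:" affects tf.variable_scope()

If the inner scope of a nested tf.variable_scope() structure is captured as vs, then re-entering it through vs is not affected by the enclosing variable_scope(); in particular, it does not inherit the outer reuse=True. Compare the two listings below to see this (the first one differs from code_2 only on line 4 and on the third line from the end).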

import tensorflow as tf

with tf.variable_scope('test1'):
    with tf.variable_scope('test2') as vs:
        var3 = tf.get_variable('var1', initializer=tf.constant(value=[1, 2, 3, 4, 5], shape=[5], dtype=tf.float32), dtype=tf.float32)
        print(var3.name)

with tf.variable_scope('test1', reuse=True):
    with tf.variable_scope(vs):
        var4 = tf.get_variable('var1', dtype=tf.float32)
        print(var4.name)

Output (an error is raised):

Traceback (most recent call last):
  File "D:/pycodeLIB/TensorFlow/test.py", line 16, in <module>
    var4 = tf.get_variable('var1', dtype=tf.float32)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1328, in get_variable
    constraint=constraint)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1090, in get_variable
    constraint=constraint)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 435, in get_variable
    constraint=constraint)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 404, in _true_getter
    use_resource=use_resource, constraint=constraint)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 743, in _get_single_variable
    name, "".join(traceback.format_list(tb))))
ValueError: Variable test1/test2/var1 already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

  File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\ops.py", line 1740, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\ops.py", line 3414, in create_op
    op_def=op_def)
  File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)

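This error directly reflects the point made above: vs is not affected by the enclosing variable_scope(), so the variable_scope() opened from vs cannot share variables, because its own reuse argument has not been set to True. The corrected code is: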

import tensorflow as tf

with tf.variable_scope('test1'):
    with tf.variable_scope('test2') as vs:
        var3 = tf.get_variable('var1', initializer=tf.constant(value=[1, 2, 3, 4, 5], shape=[5], dtype=tf.float32), dtype=tf.float32)
        print(var3.name)

with tf.variable_scope('test1', reuse=True):
    with tf.variable_scope(vs, reuse=True):
        var4 = tf.get_variable('var1', dtype=tf.float32)
        print(var4.name)

This time the output is correct:

test1/test2/var1:0
test1/test2/var1:0
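
The rule at work in this section can be summarised as follows: a scope opened by name inherits reuse from the enclosing scope, while a scope opened from a captured object keeps the reuse flag it was captured with unless reuse is passed explicitly. The toy model below encodes just that rule in plain Python, as observed with the TensorFlow version used in this post. ToyScope and enter are hypothetical names, and other TensorFlow versions may behave differently.

```python
class ToyScope:
    """Toy stand-in for a variable scope: a path plus a reuse flag."""

    def __init__(self, name, reuse=False):
        self.name = name
        self.reuse = reuse


def enter(current, name_or_scope, reuse=None):
    """Return the scope that becomes current inside the with-block (toy rules)."""
    if isinstance(name_or_scope, ToyScope):
        # Captured scope object: jump to its path and keep its own reuse flag
        # unless the caller overrides it explicitly.
        effective = True if reuse else name_or_scope.reuse
        return ToyScope(name_or_scope.name, effective)
    # Plain string: extend the current path and inherit reuse from outside.
    path = current.name + "/" + name_or_scope if current.name else name_or_scope
    return ToyScope(path, bool(reuse) or current.reuse)


root = ToyScope("")
vs = enter(enter(root, "test1"), "test2")       # first block: reuse is False
outer = enter(root, "test1", reuse=True)        # second block, outer scope

by_name = enter(outer, "test2")                 # as in code_2
by_object = enter(outer, vs)                    # as in the failing listing
by_object_fixed = enter(outer, vs, reuse=True)  # as in the corrected listing

print(by_name.name, by_name.reuse)                  # test1/test2 True
print(by_object.name, by_object.reuse)              # test1/test2 False
print(by_object_fixed.name, by_object_fixed.reuse)  # test1/test2 True
```

All three scopes resolve to the same path, so only the reuse flag decides whether tf.get_variable() would look the variable up or try to create it again.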

4. tf.name_scope: a namespace that applies only to ops, not to variables created with tf.get_variable()

import tensorflow as tf

with tf.variable_scope('test3'):
    with tf.name_scope('op_test'):
        v = tf.get_variable('v', dtype=tf.float32, initializer=tf.random_normal(shape=[6], mean=0.0, stddev=1.0))
        xx = 1.0 + v

print(v.name)
print(xx.op.name)

Output:

test3/v:0
test3/op_test/add

As the output shows, the variable created with tf.get_variable() is not affected by the 'op_test' name scope, whereas the add op is.
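
One way to remember the two outputs above is to think of a single stack of scopes in which tf.get_variable() reads only the variable_scope entries, while ops read every entry. The following is a toy illustration of that naming rule in plain Python; the stack representation and both helper functions are hypothetical and only reproduce the names printed above.

```python
def get_variable_name(stack, name):
    """Name as tf.get_variable() would report it in this toy model."""
    # Only variable_scope entries contribute to the prefix.
    prefix = [scope for kind, scope in stack if kind == "variable_scope"]
    return "/".join(prefix + [name]) + ":0"


def op_name(stack, name):
    """Name of an op in this toy model: every open scope contributes."""
    prefix = [scope for _kind, scope in stack]
    return "/".join(prefix + [name])


# Same nesting as the listing above: variable_scope outside, name_scope inside.
stack = [("variable_scope", "test3"), ("name_scope", "op_test")]

print(get_variable_name(stack, "v"))  # test3/v:0
print(op_name(stack, "add"))          # test3/op_test/add
```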

Reposted from blog.csdn.net/Triple_WDF/article/details/103201768