tensorflow之理解name_scope, variable_scope

       在训练深度网络时,为了减少需要训练参数的个数(比如具有simase结构的LSTM模型)、或是多机多卡并行化训练大数据大模型(比如数据并行化)等情况时,往往需要共享变量。另外一方面是当一个深度学习模型变得非常复杂的时候,往往存在大量的变量和操作,如何避免这些变量名和操作名的唯一不重复,同时维护一个条理清晰的graph非常重要。因此,tensorflow中用tf.Variable(),tf.get_variable(),tf.Variable_scope(),tf.name_scope()几个函数来实现:

1、tf.Variable(<variable_name>),tf.get_variable(<variable_name>):

tf.Variable和tf.get_variable都是用于在一个name_scope下面获取或创建一个变量的两种方式,区别在于:

①tf.Variable用于创建一个新变量,在同一个name_scope下面,可以创建相同名字的变量,底层实现会自动引入别名机制,两次调用产生了其实是两个不同的变量,使用tf.Variable的话每次都会新建变量。对于使用tf.Variable来说,tf.name_scope和tf.variable_scope功能一样,都是给变量加前缀,相当于分类管理,模块化。

tf.get_variable用于获取一个变量,并且不受name_scope的约束。当这个变量已经存在时,则自动获取;如果不存在,则自动创建一个变量。对于tf.get_variable来说,tf.name_scope对其无效,也就是说tf认为当使用tf.get_variable时,只归属于tf.variable_scope来管理共享与否。

②tf.Variable会自动检测命名冲突并自行处理。                                                               

tf.get_variable则遇到重名的变量创建且变量名没有设置为共享变量时,则会报错。

import tensorflow as tf

with tf.name_scope('name_scope_x'):
    var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
    var3 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
    var4 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var1.name, sess.run(var1))
    print(var3.name, sess.run(var3))
    print(var4.name, sess.run(var4))
# 输出结果:
# var1:0 [-0.30036557]   可以看到前面不含有指定的'name_scope_x'
# name_scope_x/var2:0 [ 2.]
# name_scope_x/var2_1:0 [ 2.]  可以看到变量名自行变成了'var2_1',避免了和'var2'冲突
import tensorflow as tf

with tf.name_scope('name_scope_1'):
    var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
    var2 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var1.name, sess.run(var1))
    print(var2.name, sess.run(var2))

# ValueError: Variable var1 already exists, disallowed. Did you mean 
# to set reuse=True in VarScope? Originally defined at:
# var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)

2、tf.name_scope(<scope_name>)与tf.variable_scope(<scope_name>):

tf.name_scope:

主要用于管理一个图里面的各种op,返回的是一个以scope_name命名的context manager。一个graph会维护一个name_space的 堆,每一个namespace下面可以定义各种op或者子namespace,实现一种层次化有条理的管理,避免各个op之间命名冲突。

tf.variable_scope:

通过设置reuse 标志以及初始化方式来影响域下的变量。一般与tf.name_scope()配合使用,用于管理一个graph中变量的名字,避免变量之间的命名冲突,tf.variable_scope允许在一个variable_scope下面共享变量。 需要注意的是:创建一个新的variable_scope时不需要把reuse属性设置未False,只需要在使用的时候设置为True就可以了。 在重复使用的时候, 一定要在代码中强调scope.reuse_variables(),否则系统将会报错, 以为你只是单纯的不小心重复使用到了一个变量。

with tf.variable_scope("image_filters1") as scope1:
    result1 = my_image_filter(image1)

with tf.variable_scope(scope1, reuse = True):
    result2 = my_image_filter(image2)
import tensorflow as tf

with tf.variable_scope('variable_scope_y') as scope:
    var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
    scope.reuse_variables()  # 设置共享变量
    var1_reuse = tf.get_variable(name='var1')
    var2 = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)
    var2_reuse = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var1.name, sess.run(var1))
    print(var1_reuse.name, sess.run(var1_reuse))
    print(var2.name, sess.run(var2))
    print(var2_reuse.name, sess.run(var2_reuse))
# 输出结果:
# variable_scope_y/var1:0 [-1.59682846]
# variable_scope_y/var1:0 [-1.59682846]   可以看到变量var1_reuse重复使用了var1
# variable_scope_y/var2:0 [ 2.]
# variable_scope_y/var2_1:0 [ 2.]

也可以这样

with tf.variable_scope('foo') as foo_scope:
    v = tf.get_variable('v', [1])
with tf.variable_scope('foo', reuse=True):
    v1 = tf.get_variable('v')
assert v1 == v

或者这样

with tf.variable_scope('foo') as foo_scope:
    v = tf.get_variable('v', [1])
with tf.variable_scope(foo_scope, reuse=True):
    v1 = tf.get_variable('v')
assert v1 == v

要注意的是:

with tf.variable_scope('scp', reuse=True) as scp:
    a = tf.get_varialbe('a') #报错

with tf.variable_scope('scp', reuse=False) as scp:    
     a = tf.get_varialbe('a')
    a = tf.get_varialbe('a') #报错

都会报错,因为reuse=True是,get_variable会强制共享,如果不存在,报错;reuse=Flase时,会强制创造,如果已经存在,也会报错。

如果想实现“有则共享,无则新建”的方式,可以:

with tf.variable_scope('scp', reuse=tf.AUTO_REUSE) as scp:    
     a = tf.get_variable('a') #无,创造
     a = tf.get_variable('a') #有,共享

3、总结:

tf.variable_scope和tf.get_variable必须要搭配使用(全局scope除外),为share提供支持。

tf.Variable可以单独使用,也可以搭配tf.name_scope使用,给变量分类命名,模块化。

tf.Variable和tf.variable_scope搭配使用不伦不类,不是设计者的初衷。

参考:https://www.zhihu.com/question/54513728

猜你喜欢

转载自blog.csdn.net/lcczzu/article/details/84921096
今日推荐