TensorFlow的变量管理:变量作用域机制

在深度学习中,你可能需要用到大量的变量集,而且这些变量集可能在多处都要用到。例如,训练模型时,训练参数如权重(weights)、偏置(biases)等已经定下来,要拿到验证集去验证,我们自然希望这些参数是同一组。以往写简单的程序,可能使用全局限量就可以了,但在深度学习中,这显然是不行的,一方面不便管理,另外这样一来代码的封装性受到极大影响。因此,TensorFlow提供了一种变量管理方法:变量作用域机制,以此解决上面出现的问题。

TensorFlow的变量作用域机制依赖于以下两个方法,官方文档中定义如下:

[plain]  view plain  copy
 
  1. tf.get_variable(name, shape, initializer): Creates or returns a variable with a given name.建立或返回一个给定名称的变量  
  2. tf.variable_scope( scope_name): Manages namespaces for names passed to tf.get_variable(). 管理传递给tf.get_variable()的变量名组成的命名空间  

先说说tf.get_variable(),这个方法在建立新的变量时与tf.Variable()完全相同。它的特殊之处在于,他还会搜索是否有同名的变量。创建变量用法如下:

[plain]  view plain  copy
 
  1. with tf.variable_scope("foo"):  
  2.     with tf.variable_scope("bar"):  
  3.         v = tf.get_variable("v", [1])  
  4.         assert v.name == "foo/bar/v:0"  


而tf.variable_scope(scope_name),它会管理在名为scope_name的域(scope)下传递给tf.get_variable的所有变量名(组成了一个变量空间),根据规则确定这些变量是否进行复用。这个方法最重要的参数是reuse,有None,tf.AUTO_REUSE与True三个选项。具体用法如下:

  1. reuse的默认选项是None,此时会继承父scope的reuse标志。
  2. 自动复用(设置reuse为tf.AUTO_REUSE),如果变量存在则复用,不存在则创建。这是最安全的用法,在使用新推出的EagerMode时reuse将被强制为tf.AUTO_REUSE选项。用法如下:
    [plain]  view plain  copy
     
    1. def foo():  
    2.   with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):  
    3.     v = tf.get_variable("v", [1])  
    4.   return v  
    5.   
    6. v1 = foo()  # Creates v.  
    7. v2 = foo()  # Gets the same, existing v.  
    8. assert v1 == v2  
  3. 复用(设置reuse=True):
    [plain]  view plain  copy
     
    1. with tf.variable_scope("foo"):  
    2.   v = tf.get_variable("v", [1])  
    3. with tf.variable_scope("foo", reuse=True):  
    4.   v1 = tf.get_variable("v", [1])  
    5. assert v1 == v  
  4. 捕获某一域并设置复用(scope.reuse_variables()):
    [plain]  view plain  copy
     
    1. with tf.variable_scope("foo") as scope:  
    2.   v = tf.get_variable("v", [1])  
    3.   scope.reuse_variables()  
    4.   v1 = tf.get_variable("v", [1])  
    5. assert v1 == v  

    1)非复用的scope下再次定义已存在的变量;或2)定义了复用但无法找到已定义的变量,TensorFlow都会抛出错误,具体如下:
[plain]  view plain  copy
 
  1. with tf.variable_scope("foo"):  
  2.     v = tf.get_variable("v", [1])  
  3.     v1 = tf.get_variable("v", [1])  
  4.     #  Raises ValueError("... v already exists ...").  
  5.   
  6.   
  7. with tf.variable_scope("foo", reuse=True):  
  8.     v = tf.get_variable("v", [1])  
  9.     #  Raises ValueError("... v does not exists ...").  
 
转自: https://blog.csdn.net/zbgjhy88/article/details/78960388

猜你喜欢

转载自www.cnblogs.com/pzf9266/p/9012296.html