引言:
tf提供了一种通过变量名称来创建或者获取一个变量的机制。通过这个机制,在不同的函数中可以直接通过变量的名字来获取变量,而不需要将变量通过参数的形式到处传递。
tf中通过变量名字获取变量的机制主要是通过tf.get_variable 和tf.variable_scope 函数实现的。
tf.get_variable 和 tf.Variable :
下面两个定义是等价的:
v = tf.get_variable("v", shape=[1], initializer=tf.constant_initializer(1.0)) v = tf.Variable(tf.constant(1.0, shape=[1]), name='v')
二者区别:
对于 tf.Variable函数,变量名称是一个可选的参数,通过 name='v'的形式给出,但是对于 tf.get_variable 函数,变量名是一个必填的参数。tf.get_variable会根据这个名字去获取变量。它可以避免同名错误。
tf.get_variable 和 tf.variable_scope :
若需要通过tf.get_variable 获取一个已经生成的变量,需要通过tf.variable_scope 函数来生成一个上下文管理器,并指定在这个上下文管理器中,tf.get_variable 函数将直接获取已经生成的变量。
#在名字为foo的命名空间中创建名字为v的变量 with tf.variable_scope("foo"): v = tf.get_variable("v", initializer=tf.constant_initializer(1.0)) #在上下文管理器中,将参数reuse设置为True,这样tf.get_variable 函数将直接获取已经声明的变量 #将参数reuse设置为True时,将只能获取已经创建过的变量 with tf.variable_scope("foo", reuse=True): v1 = tf.get_variable("v", [1])当tf.variable_scope 使用参数 reuse=None 或者是reuse=False创建上下文管理器,tf.get_variable将创建新的变量。如果同名的变量已经存在,那么报错。
tf中的tf.variable_scope函数是可以嵌套的。