许多的 TensorFlow 开源项目都会频繁出现 tf.variable_scope, tf.name_scope, tf.get_variable(), tf.Variable() ,今天来对此做一个总结。
注意,tf.Variable() 有大写!
首先来谈谈 tf.get_variable() 与 tf.Variable(),因为如果使用 variable() 的话每次都会新建变量,但是大多数时候我们是希望一些变量重用的,所以就用到了get_variable()。它会去搜索变量名,然后没有就新建,有就直接用。
既然用到变量名了,就涉及到了名字域的概念。通过不同的域来区别变量名,毕竟让我们给所有变量都直接取不同名字还是有点辛苦的。所以为什么会有 scope 的概念。
语言总是太苍白,直接上几段代码示例对比分析一下:
当用 tf.Variable() 创建相同的变量名时:
import tensorflow as tf
with tf.name_scope('FeiGe'):
hf1 = tf.get_variable(name='hf1', shape=[1], dtype=tf.float32)
hf3 = tf.Variable(name='hf2', initial_value=[2], dtype=tf.float32)
hf4 = tf.Variable(name='hf2', initial_value=[2], dtype=tf.float32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(hf1.name, sess.run(hf1))
print(hf3.name, sess.run(hf3))
print(hf4.name, sess.run(hf4))
运行结果如下:
当用 tf.get_variable() 创建相同变量名的变量,但没有设置共享变量时:
import tensorflow as tf
with tf.name_scope('FeiGe'):
hf1 = tf.get_variable(name='hf1', shape=[1], dtype=tf.float32)
hf2 = tf.get_variable(name='hf1', shape=[1], dtype=tf.float32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(hf1.name, sess.run(hf1))
print(hf2.name, sess.run(hf2))
重名会报错,询问是否想要把 reuse 值设为 True。
当需要共享变量时,用 tf.variable_scope()
import tensorflow as tf
with tf.variable_scope('FeiGe') as scope:
hf1 = tf.get_variable(name='hf1', shape=[1], dtype=tf.float32)
scope.reuse_variables() # 设置共享变量
hf1_reuse = tf.get_variable(name='hf1')
hf2 = tf.Variable(initial_value=[2.], name='hf2', dtype=tf.float32)
hf2_reuse = tf.Variable(initial_value=[2.], name='hf2', dtype=tf.float32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(hf1.name, sess.run(hf1))
print(hf1_reuse.name, sess.run(hf1_reuse))
print(hf2.name, sess.run(hf2))
print(hf2_reuse.name, sess.run(hf2_reuse))
共享变量还可写为:
with tf.variable_scope('FeiGe') as FeiGe_scope:
v = tf.get_variable('v', [1])
with tf.variable_scope('FeiGe', reuse=True):
v1 = tf.get_variable('v')
assert v1 == v
或者:
with tf.variable_scope('FeiGe') as FeiGe_scope:
v = tf.get_variable('v', [1])
with tf.variable_scope(FeiGe_scope, reuse=True):
v1 = tf.get_variable('v')
assert v1 == v
为什么需要共享变量?
这里参考网络上的资料,给出一个例子:
当我们研究 GAN 的时候,判别器的任务是,如果接收到的是生成器生成的图像,判别器就尝试优化自己的网络结构来使自己输出0,如果接收到的是来自真实数据的图像,那么就尝试优化自己的网络结构来使自己输出1。也就是说,生成图像和真实图像经过判别器的时候,要共享同一套变量,所以TensorFlow引入了变量共享机制。