tf.variable_scope(), tf.name_scope(), tf.get_variable(), tf.Variable() 理解总结

许多的 TensorFlow 开源项目都会频繁出现 tf.variable_scope, tf.name_scope, tf.get_variable(), tf.Variable() ,今天来对此做一个总结。

注意,tf.Variable() 有大写!

首先来谈谈 tf.get_variable() 与 tf.Variable(),因为如果使用 variable() 的话每次都会新建变量,但是大多数时候我们是希望一些变量重用的,所以就用到了get_variable()。它会去搜索变量名,然后没有就新建,有就直接用。

既然用到变量名了,就涉及到了名字域的概念。通过不同的域来区别变量名,毕竟让我们给所有变量都直接取不同名字还是有点辛苦的。所以为什么会有 scope 的概念。

语言总是太苍白,直接上几段代码示例对比分析一下:

当用 tf.Variable() 创建相同的变量名时:
import tensorflow as tf

with tf.name_scope('FeiGe'):
    hf1 = tf.get_variable(name='hf1', shape=[1], dtype=tf.float32)
    hf3 = tf.Variable(name='hf2', initial_value=[2], dtype=tf.float32)
    hf4 = tf.Variable(name='hf2', initial_value=[2], dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(hf1.name, sess.run(hf1))
    print(hf3.name, sess.run(hf3))
    print(hf4.name, sess.run(hf4))

运行结果如下:
这里写图片描述

当用 tf.get_variable() 创建相同变量名的变量,但没有设置共享变量时:
import tensorflow as tf

with tf.name_scope('FeiGe'):
    hf1 = tf.get_variable(name='hf1', shape=[1], dtype=tf.float32)
    hf2 = tf.get_variable(name='hf1', shape=[1], dtype=tf.float32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(hf1.name, sess.run(hf1))
    print(hf2.name, sess.run(hf2))

这里写图片描述
重名会报错,询问是否想要把 reuse 值设为 True。

当需要共享变量时,用 tf.variable_scope()
import tensorflow as tf

with tf.variable_scope('FeiGe') as scope:
    hf1 = tf.get_variable(name='hf1', shape=[1], dtype=tf.float32)
    scope.reuse_variables()  # 设置共享变量
    hf1_reuse = tf.get_variable(name='hf1')
    hf2 = tf.Variable(initial_value=[2.], name='hf2', dtype=tf.float32)
    hf2_reuse = tf.Variable(initial_value=[2.], name='hf2', dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(hf1.name, sess.run(hf1))
    print(hf1_reuse.name, sess.run(hf1_reuse))
    print(hf2.name, sess.run(hf2))
    print(hf2_reuse.name, sess.run(hf2_reuse))

这里写图片描述

共享变量还可写为:

with tf.variable_scope('FeiGe') as FeiGe_scope:
    v = tf.get_variable('v', [1])
with tf.variable_scope('FeiGe', reuse=True):
    v1 = tf.get_variable('v')
assert v1 == v

或者:

with tf.variable_scope('FeiGe') as FeiGe_scope:
    v = tf.get_variable('v', [1])
with tf.variable_scope(FeiGe_scope, reuse=True):
    v1 = tf.get_variable('v')
assert v1 == v
为什么需要共享变量?

这里参考网络上的资料,给出一个例子:

当我们研究 GAN 的时候,判别器的任务是,如果接收到的是生成器生成的图像,判别器就尝试优化自己的网络结构来使自己输出0,如果接收到的是来自真实数据的图像,那么就尝试优化自己的网络结构来使自己输出1。也就是说,生成图像和真实图像经过判别器的时候,要共享同一套变量,所以TensorFlow引入了变量共享机制。

发布了225 篇原创文章 · 获赞 648 · 访问量 89万+

猜你喜欢

转载自blog.csdn.net/huangfei711/article/details/81024286
今日推荐