Tensorflow共享变量机制

背景

我们定义变量通常使用tf.Variable()的方式进行创建变量,但在某种情况下,一个模型需要使用其他模型创建的变量,两个模型一起训练。比如:对抗网络中的生成器模型和判别器模型。如果使用Variable进行创建,那么得到的是一个新的变量,而非原来的变量。这时就需引入共享变量解决问题。

tf.get_variable

使用tf.Variable创建变量,每次都会在内存中生成一个新的var。而使用tf.get_variable可以获取一个已经存在的变量或者创建一个新变量。并且get_variable只能定义一次指定名称的变量。

通常get_variable与vaiiable_scope一起使用。variable_scope的意思是变量作用域。类似于c++中的namespace。在某一作用域中的变量可以被设置为共享的方式,被其他的网络模型使用。

在特定作用于域下获取变量

相同作用域中使用get_variable创建两个相同名字的变量是行不通的,如果真的需要创建两个相同名字的变量,则需要使用不同的scope将它们隔开。并且scope支持嵌套。

相同变量名举例:

import tensorflow as tf
with tf.variable_scope("test1"):
    var1 = tf.get_variable("first", shape=[2], dtype=tf.float32)
with tf.variable_scope("test2"):
    var2 = tf.get_variable("first", shape=[2], dtype=tf.float32)
print(var1.name)
print(var2.name)

输出为:

test1/first:0
test2/first:0

支持嵌套举例:

import tensorflow as tf
with tf.variable_scope("test1"):
    var1 = tf.get_variable("first", shape=[2], dtype=tf.float32)
    with tf.variable_scope("test2"):
        var2 = tf.get_variable("second", shape=[2], dtype=tf.float32)
print(var1.name)
print(var2.name)

输出为:

test1/first:0
test1/test2/second:0

实现共享变量的功能

使用get_variable无非就是通过它实现共享变量的功能,故我们可以用variable_scope中的reuse属性,将其设置为True,表示使用已经定义过的变量。此时在get_variable中将不会创建新的变量,而是去图中找与name相同的get_variable。(需要建立相同的scope)

变量共享举例:

import tensorflow as tf
with tf.variable_scope("test1"):
    var1 = tf.get_variable("first", shape=[2], dtype=tf.float32, initializer=tf.constant_initializer(0.1))
with tf.variable_scope("test1", reuse=True):
    var2 = tf.get_variable("first", shape=[2], dtype=tf.float32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(var1))
    print(sess.run(var2))

输出为:

[0.1 0.1]
[0.1 0.1]

可使用tf.AUTO_REUSE来为reuse赋值,可实现第一次调用variable_scope时,传入的reuse值为False, 再次调用variable_scope时传入reuse的值会自动变为True。


tf.name_scope作用域限制op操作,不限制变量的命名。 也就是说变量只受variable_scope的限制,而操作符却要受variable_scope和name_scope的双重限制。

import tensorflow as tf
with tf.variable_scope("test1") as scope1:
    with tf.name_scope("name"):
        var1 = tf.get_variable("first", shape=[2], dtype=tf.float32)
        x = var1 + 1.0
print(var1.name)
print(x.name)

输出结果:

test1/first:0
test1/name/add:0

一个小tip:

name_scope可用空字符使作用域返回到顶层

import tensorflow as tf
with tf.variable_scope("test1") as scope1:
    with tf.name_scope("name"):
        var1 = tf.get_variable("first", shape=[2], dtype=tf.float32)
        x = var1 + 1.0
        with tf.name_scope(""):
            y = 1.0 + x
print(var1.name)
print(x.name)
print(y.name)

输出结果:

test1/first:0
test1/name/add:0
add:0
发布了267 篇原创文章 · 获赞 51 · 访问量 25万+

猜你喜欢

转载自blog.csdn.net/AcSuccess/article/details/89402590