TensorFlow --- 共享变量

1.使用get_variable获取变量
get_variable一般会配合variable_scope使用,以实现共享变量。
variable_scope的意思是变量作用域。

# get_variable的定义

tf.get_variable(<name>, <shape>,<initializer>)

使用一般通过name属性定位到具体变量,并将其共享到其他模型中。

2.get_variable和Variable的区别
(1)Variable的用法

    import tensorflow as tf

    var1 = tf.Variable(1.0, name='firstvar')
    print('var1',var1.name)
    var1 = tf.Variable(2.0, name='firstvar')
    print('var1',var1.name)
    var2 = tf.Variable(3.0)
    print('var2',var2.name)
    var2 = tf.Variable(4.0)
    print('var2',var2.name)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # eval()函数:取返回值
        print('var1=', var1.eval())
        print('var2=', var2.eval())

    结果为:
    var1 firstvar:0
    var1 firstvar_1:0
    var2 Variable:0
    var2 Variable_1:0
    var1=2.0
    var2=4.0

上述代码中定义了两次var1,可以看到在内存中生成了两个var1,对于图来讲后面的var1是生效的。当Variable定义是没有指定名字,系统会自动加上一个名字Variable:0

(2)get_variable用法

接上述代码
...
get_var1 = tf.get_variable('firstvar', [1], initializer=tf.constant_initializer(0.3))
print('get_var1:',get_var1.name)
get_var2 = tf.get_variable('firstvar', [1], initializer=tf.constant_initializer(0.4))
print('get_var1:',get_var1.name)

结果为:
get_var1: firstvar_2:0
ValueError: Variable firstvar already exists, disallowed.

可以看到,程序在定义第二个get_variable1时发生错误了。这表明,使用get_variable只能定义一次指定名称的变量。同时由于变量firstvar在前面使用Variable函数生成过一次,所有系统自动变成了firstvar_2:0


修改后的代码为

get_var1 = tf.get_variable('firstvar', [1], initializer=tf.constant_initializer(0.3))
print('get_var1:',get_var1.name)
get_var1 = tf.get_variable('firstvar1', [1], initializer=tf.constant_initializer(0.4))
print('get_var1:',get_var1.name)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('get_var1=',get_var1.eval())

结果为:
get_var1: firstvar_2:0
get_var1: firstvar1:0
get_var1=[0.40000001]

3.在特定作用域下获取变量

import tensorflow as tf

with tf.variable_scope('test1'):
    var1 = tf.get_variable('firstvar', shape=[2], dtype=tf.float32)
with tf.variable_scope('test2'):
    var2 = tf.get_variable('firstvar', shape=[2], dtype=tf.float32)

print('var1:', var1.name)
print('var2', var2.name)

结果为:
var1: test1/firstvar:0
var2: test2/firstvar:0

var1和var2都使用firstvar的名字来定义。通过输出可以看出,其实生成的两个变量var1和var2是不同的,它们作用在不同的scope下,这就是scope的作用
scope还支持嵌套,将上面代码中的第二个scope缩进以下,得到如下代码:

with tf.variable_scope('test1'):
    var1 = tf.get_variable('firstvar', shape=[2], dtype=tf.float32)
    with tf.variable_scope('test2'):
        var2 = tf.get_variable('firstvar', shape=[2], dtype=tf.float32)

print('var1:', var1.name)
print('var2', var2.name)

结果为:
var1: test1/firstvar:0
var2: test1/test2/firstvar:0

4.共享变量功能的实现

import tensorflow as tf

with tf.variable_scope('test1'):
    var1 = tf.get_variable('firstvar', shape=[2], dtype=tf.float32)
with tf.variable_scope('test2'):
    var2 = tf.get_variable('firstvar', shape=[2], dtype=tf.float32)

print('var1:', var1.name)
print('var2', var2.name)

with tf.variable_scope('test1', reuse=True):
    var3 = tf.get_variable('firstvar', shape=[2], dtype=tf.float32)
    with tf.variable_scope('test2'):
        var4 = tf.get_variable('firstvar', shape=[2], dtype=tf.float32)

print('var3:', var3.name)
print('var4:', var4.name)

结果为:
var1: test1/firstvar:0
var2: test1/test2/firstvar:0
var3: test1/firstvar:0
var4: test1/test2/firstvar:0

通过上述代码表明var1和var3共用了一个变量,var2和var4共用了一个变量,这就实现了共享变量。在实际应用中,可以把var1和var2放到一个网络模型里取训练,把var3和var4放到另一个网络模型里去训练,而两个模型的训练结果都会作用于一个模型的学习参数上。

5.初始化共享变量的作用域
variable_scope和get_variable都有初始化的功能。在初始化时,如果没有对当前变量初始化,TensorFlow会默认使用作用域的初始化方法对其初始化,并且作用域的初始化方法也有继承功能

import tensorflow as tf

with tf.variable_scope('test1', initializer=tf.constant_initializer(0.4)):
    var1 = tf.get_variable('firstvar', shape=[2], dtype=tf.float32)
    with tf.variable_scope('test2'):
        var2 = tf.get_variable('firstvar', shape=[2], dtype=tf.float32)
        var3 = tf.get_variable('var3', shape=[2], initializer=tf.constant_initializer(0.3))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('var1=', var1.eval()) # 作用域test1下的变量
    print('var2=', var2.eval()) # 作用域test2下的变量,继承test1的初始化
    print('var3=', var3.eval()) # 作用域test2下的变量

结果为:
var1 = [0.4 0.4]
var2 = [0.4 0.4]
var3 = [0.3 0.3]

var1数组值为0.4,表明继承了test1的值;var2数组值为0.4,表明其所在的作用域test2也继承了test1的初始化;变量var3在在创建是同步指定了初始化操作,所以数组值为0.3

6.作用域与操作符的受限范围

# -*-coding:utf-8 -*-

import tensorflow as tf

with tf.variable_scope('scope1') as sp:
    var1 = tf.get_variable('v', [1])

print('sp:',sp.name)
print('var1:', var1.name)

with tf.variable_scope('scope2'):
    var2 = tf.get_variable('v', [1])

    with tf.variable_scope(sp) as sp1:
        var3 = tf.get_variable('v3', [1])

print('sp1:', sp1.name)
print('var2:', var2.name)
print('var3:', var3.name)
结果为:
sp: scope1
var1: scope1/v:0
sp1: scope1
var2: scope2/v:0
var3: scope1/v3:0

从上述的代码可以看出,sp1在scope2下,但是输出的仍然是scope1,没有改版。在它下面定义的var3的名字是scope1/v3:0,表明也在scope1下,再次说明sp没有收到外层的限制。

with tf.variable_scope('scope'):
    with tf.name_scope('bar'):
        v = tf.get_variable('v', [1])
        x = 1.0 + v
print('v:', v.name)
print('x.op:', x.op.name)
结果为:
v: scope/v:0
x.op: scope/bar/add

从上述代码可以看出,操作符不仅受到tf.name_scope作用域的限制,同时也受到tf.variable_scope作用域的限制。tf.name_scope只能限制op,不能限制变量的命名。


比较tf.name_scope与variable_scope在空字符串情况下的处理

with tf.variable_scope('scope2'):
    var2 = tf.get_variable('v', [1])
    with tf.variable_scope(sp) as sp1:
        var3 = tf.get_variable('v3', [1])
        with tf.variable_scope(''):
            var4 = tf.get_variable('v4', [1])

with tf.variable_scope('scope'):
    with tf.name_scope('bar'):
        v = tf.get_variable('v', [1])
        x = 1.0 + v
        with tf.name_scope(''):
            y = 1.0 + v
print('v4:', v4.name)
print('y.op:', y.op.name)
结果为:
var4: scope1//v4:0
y.op: add

从上述代码可以看出,y变成顶层了,而var4多了一个空层。在tf.name_scope函数中,可以使用空字符将作用域返回到顶层。

猜你喜欢

转载自blog.csdn.net/jian15093532273/article/details/80764739