TensorFlow Variable Management and Variable Sharing

Today I'm writing down some notes on using tf.Variable and tf.get_variable. While implementing a GAN I fell into a deep pit: the results were always poor, yet no error was ever reported. It turned out the shared weights were not handled correctly. I finally got it sorted out and was thrilled, so I'm recording it here in the hope that it helps you avoid the same trap.

1. Using tf.Variable

Typical call:

weights = tf.Variable(tf.constant(0.1, shape = shape), name = "weights")

2. Using tf.get_variable

Typical call:

weights = tf.get_variable("weights", shape,
         initializer = tf.truncated_normal_initializer(stddev = 0.1))
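One behavior worth keeping in mind (my own addition, the original only gets to it in section 3): tf.get_variable treats the name as a key in the current variable scope, so defining the same name twice in one scope without reuse raises an error. A minimal sketch, with 'demo' as an illustrative scope name:

import tensorflow as tf

with tf.variable_scope('demo'):
    w1 = tf.get_variable("weights", [5],
                         initializer = tf.truncated_normal_initializer(stddev = 0.1))
    # A second definition of the same name in this scope would raise a
    # ValueError saying the variable demo/weights already exists:
    # w2 = tf.get_variable("weights", [5])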

3. Differences Between the Two

3.1 Repeated calls with tf.Variable

When tf.Variable is called repeatedly, it automatically creates a new variable with a uniquified name:

import tensorflow as tf

def test():
    # Create a variable inside the D_layer1 scope (reuse defaults to False)
    with tf.variable_scope('D_layer1'):
        weights1 = tf.Variable(tf.constant(0.1, shape = [5]), name = "weights")
        name1 = weights1.name
    # Create a variable inside the D_layer2 scope (reuse defaults to False)
    with tf.variable_scope('D_layer2'):
        weights2 = tf.Variable(tf.constant(0.1, shape = [5]), name = "weights")
        name2 = weights2.name
    return name1, name2

tf.variable_scope('D_layer1') opens a namespace called D_layer1, and every variable created inside it is named under that scope. Calling the function above (in a module named variabletest) repeatedly gives the following:

import variabletest
import tensorflow as tf
import numpy as np

name11, name12 = variabletest.test()
name21, name22 = variabletest.test()
print(name11)
print(name12)
print(name21)
print(name22)

D_layer1/weights:0
D_layer2/weights:0
D_layer1_1/weights:0
D_layer2_1/weights:0
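For contrast, here is a minimal sketch (my own addition, not in the original post) showing that tf.Variable ignores the reuse flag of the surrounding scope entirely: it always creates a brand-new variable with a uniquified name, which is why it cannot be used for sharing.

import tensorflow as tf

def make_var():
    # reuse=True has no effect on tf.Variable; only tf.get_variable honors it
    with tf.variable_scope('D_layer1', reuse = True):
        v = tf.Variable(tf.constant(0.1, shape = [5]), name = "weights")
    return v.name

print(make_var())  # D_layer1/weights:0
print(make_var())  # D_layer1_1/weights:0 -- a new variable, nothing is shared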

3.2 Implementing variable sharing

With tf.get_variable, the behavior on a repeated call depends on the scope's reuse flag: when reuse=True, it looks up and returns the existing variable instead of creating a new one:

import tensorflow as tf

def test(reuse):
    # Create (or, when reuse=True, look up) the variable inside the D_layer1 scope
    with tf.variable_scope('D_layer1', reuse = reuse):
        weights1 = tf.get_variable("weights", [5], initializer = tf.truncated_normal_initializer(stddev = 0.1))
        name1 = weights1.name
    # Create (or, when reuse=True, look up) the variable inside the D_layer2 scope
    with tf.variable_scope('D_layer2', reuse = reuse):
        weights2 = tf.get_variable("weights", [5], initializer = tf.truncated_normal_initializer(stddev = 0.1))
        name2 = weights2.name
    return name1, name2

In a GAN-like network we need to share weights, which means calling the same forward-pass function more than once. tf.Variable cannot achieve this sharing unless the variables are created in the main function and passed around, which breaks encapsulation. Using tf.get_variable together with tf.variable_scope solves the problem. The result:

import variabletest
import tensorflow as tf
import numpy as np

name11, name12 = variabletest.test(False)
name21, name22 = variabletest.test(True)
print(name11)
print(name12)
print(name21)
print(name22)

D_layer1/weights:0
D_layer2/weights:0
D_layer1/weights:0
D_layer2/weights:0

When the reuse flag of tf.variable_scope is set to False, tf.get_variable creates a new variable (and raises an error if one with the same name already exists in that scope); when it is True, it looks up and returns the existing variable (and raises an error if none exists). With this, the code above shares the same variables across the two calls.
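To connect this back to the GAN scenario mentioned above, here is a hedged sketch of how a discriminator's forward pass can be wrapped so that the second call reuses the first call's weights. The names discriminator, real, fake and the layer sizes are made up for illustration and are not from the original post:

import tensorflow as tf

def discriminator(x, reuse):
    # Variables created via tf.get_variable inside this scope are shared
    # across calls as long as reuse=True is passed on later calls
    with tf.variable_scope('discriminator', reuse = reuse):
        w = tf.get_variable("weights", [784, 1],
                            initializer = tf.truncated_normal_initializer(stddev = 0.1))
        b = tf.get_variable("bias", [1], initializer = tf.zeros_initializer())
        return tf.matmul(x, w) + b

real = tf.placeholder(tf.float32, [None, 784])
fake = tf.placeholder(tf.float32, [None, 784])

d_real = discriminator(real, reuse = False)  # first call creates the variables
d_fake = discriminator(fake, reuse = True)   # second call reuses the same variables

In TensorFlow 1.4 and later, passing reuse=tf.AUTO_REUSE lets the scope create variables on the first call and reuse them afterwards, which avoids threading the flag through by hand.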


Source: blog.csdn.net/michael__corleone/article/details/78906318