tensorflow入门学习(3)——tensorflow共享变量

当创建复杂的模块时,通常你需要共享大量变量集并且如果你还想在同一个地方初始化这所有的变量,
可以通过共享变量实现。

先看一个图片过滤器的情景:

def my_image_filter(input_images):
    conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
        name="conv1_weights")
    conv1_biases = tf.Variable(tf.zeros([32]), name="conv1_biases")
    conv1 = tf.nn.conv2d(input_images, conv1_weights,
        strides=[1, 1, 1, 1], padding='SAME')
    relu1 = tf.nn.relu(conv1 + conv1_biases)

    conv2_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
        name="conv2_weights")
    conv2_biases = tf.Variable(tf.zeros([32]), name="conv2_biases")
    conv2 = tf.nn.conv2d(relu1, conv2_weights,
        strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(conv2 + conv2_biases)

你很容易想到,模块集很快就比一个模块变得更为复杂

仅在这里我们就有了四个不同的变量:conv1_weights,conv1_biases, conv2_weights, 和conv2_biases.

当我们想重用这个模块时问题还在增多.假设你想把你的图片过滤器运用到两张不同的图片, image1和image2.

你想通过拥有同一个参数的同一个过滤器来过滤两张图片,你可以调用my_image_filter()两次,但是这会产生两组变量.

# First call creates one set of variables.
result1 = my_image_filter(image1)
# Another set is created in the second call.
result2 = my_image_filter(image2)

变量作用域

TensorFlow 提供了变量作用域 机制,当构建一个视图时,很容易就可以共享命名过的变量.
变量作用域机制在TensorFlow中主要由两部分组成:

tf.get_variable(<name>, <shape>, <initializer>): 通过所给的名字创建或是返回一个变量.
tf.variable_scope(<scope_name>): 通过 tf.get_variable()为变量名指定命名空间.

方法 tf.get_variable() 用来获取或创建一个变量,而不是直接调用tf.Variable.
它采用的不是像tf.Variable这样直接获取值来初始化的方法.一个初始化就是一个方法,创建其形状并且为这个形状提供一个张量.这里有一些在TensorFlow中使用的初始化变量:

tf.constant_initializer(value) 初始化一切所提供的值,
tf.random_uniform_initializer(a, b) 从a到b均匀初始化,
tf.random_normal_initializer(mean, stddev) 用所给平均值和标准差初始化均匀分布.
示例:
这里写图片描述
首先因为weights, biases相当于被重复创建4次(2层,2张图片)
那么可以为每层的创建一个variable_scope,这样变量就可以共享

1. tf.get_variable()

v = tf.get_variable(name, shape, dtype, initializer)

情况1:当tf.get_variable_scope().reuse == False时,作用域就是为创建新变量所设置的.

这种情况下,v将通过tf.Variable所提供的形状和数据类型来重新创建.创建变量的全称将会由当前变量作用域名+所提供的名字所组成,并且还会检查来确保没有任何变量使用这个全称.如果这个全称已经有一个变量使用了,那么方法将会抛出ValueError错误

with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])

情况2:当tf.get_variable_scope().reuse == True时,作用域是为重用变量所设置

这种情况下,调用就会搜索一个已经存在的变量,他的全称和当前变量的作用域名+所提供的名字是否相等.如果不存在相应的变量,就会抛出ValueError 错误

with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])
with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v", [1])

2. tf.variable_scope() & tf.variable_scope().reuse_variables()

当前变量作用域可以用tf.get_variable_scope()进行检索并且reuse标签可以通过调用tf.get_variable_scope().reuse_variables()设置为True .

with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])
    tf.get_variable_scope().reuse_variables()
    v1 = tf.get_variable("v", [1])
assert v1 == v

猜你喜欢

转载自blog.csdn.net/qq_37423198/article/details/80526437