tf.add_to_collection 和 tf.get_collection 和 tf.add_n

简要介绍

tf.add_to_collection:把多个变量放在一个 自己命名 的集合里,包括不同域内的变量

tf.get_collection:读取一个列表,生成一个新列表

tf.add_n:把一个列表里的元素求和

add_to_collection(name, value)

name 为集合名,value 为 变量;

通常 tensorflow 会把变量和可训练的变量自动收集起来,包括不同域的变量;

变量对应的集合名字叫 variables,或者叫 tf.GraphKeys.VARIABLES;

可训练的变量对应的集合名字为 trainable_variables,或者叫 tf.GraphKeys.TRAINABLE_VARIABLES;

print(tf.GraphKeys.VARIABLES)           # variables
print(tf.GraphKeys.TRAINABLE_VARIABLES) # trainable_variables
扫描二维码关注公众号,回复: 9644371 查看本文章

示例

with tf.name_scope('test1') as test1:
    v1 = tf.Variable(1)
    tf.add_to_collection('all', v1)     ### 显式加入集合

with tf.name_scope('test2') as test2:
    v2 = tf.Variable(2)
    tf.add_to_collection('all', v2)     ### 显式加入集合

for i in tf.get_collection(tf.GraphKeys.VARIABLES):     ### tf 自动收集
    print(i)
# <tf.Variable 'test1/Variable:0' shape=() dtype=int32_ref>
# <tf.Variable 'test2/Variable:0' shape=() dtype=int32_ref>

for j in tf.get_collection('all'):
    print(j)
# <tf.Variable 'test1/Variable:0' shape=() dtype=int32_ref>
# <tf.Variable 'test2/Variable:0' shape=() dtype=int32_ref>

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(tf.add_n(tf.get_collection('all')))) # 3

它的作用是不停地记录关注变量,然后求和

d1 = tf.Variable(1)
d2 = tf.Variable(2)
d3 = tf.Variable(3)
tf.add_to_collection('sum', d1)
tf.add_to_collection('sum', d2)
tf.add_to_collection('sum', d3)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(tf.add_n(tf.get_collection('sum'))))     # 6

get_collection(key, scope=None)

key 集合名,scope 作用域

示例

v1 = tf.Variable(1, name='v1')
v2 = tf.get_variable(name='v2', initializer=2)
v3 = tf.Variable(3, name='v3', trainable=False)

print(tf.get_variable_scope().name)       #
print(tf.GraphKeys.VARIABLES)           # variables
print(tf.GraphKeys.TRAINABLE_VARIABLES) # trainable_variables

### 获取全部变量 key=variables,scope=None
for j in tf.get_collection(tf.GraphKeys.VARIABLES):
    print(j)
# <tf.Variable 'v1:0' shape=() dtype=int32_ref>
# <tf.Variable 'v2:0' shape=() dtype=int32_ref>
# <tf.Variable 'v3:0' shape=() dtype=int32_ref>

### 获取全部可训练变量 key=trainable_variables,scope=None
for k in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    print(k)
# <tf.Variable 'v1:0' shape=() dtype=int32_ref>
# <tf.Variable 'v2:0' shape=() dtype=int32_ref>

### 获取全部可训练变量 key=trainable_variables,scope=None,等价于上个操作
for m in tf.get_collection('trainable_variables'):
    print(m)
# <tf.Variable 'v1:0' shape=() dtype=int32_ref>
# <tf.Variable 'v2:0' shape=() dtype=int32_ref>

示例2:指定作用域,接上例

### 增加一个作用域
with tf.name_scope('test') as test:
    v4 = tf.Variable(4, name='v4')
    v5 = tf.Variable(5, name='v5', trainable=False)

### 获取全部可训练变量 key=trainable_variables,scope=None,包括新的作用域
for s in tf.get_collection('trainable_variables'):
    print(s)
# <tf.Variable 'v1:0' shape=() dtype=int32_ref>
# <tf.Variable 'v2:0' shape=() dtype=int32_ref>
# <tf.Variable 'test/v4:0' shape=() dtype=int32_ref>

### 获取指定作用域下的可训练变量 key=trainable_variables,scope=test
for t in tf.get_collection('trainable_variables', test):
    print(t)
# <tf.Variable 'test/v4:0' shape=() dtype=int32_ref>

add_n(inputs, name=None)

很简单了,上面的例子中有用到

参考资料:

https://blog.csdn.net/uestc_c2_403/article/details/72415791

https://blog.csdn.net/nini_coded/article/details/80528466

猜你喜欢

转载自www.cnblogs.com/yanshw/p/12435071.html