tf.add_to_collection() 和tf.get_collection()语句的组合使用旨在更好的管理同一类型(或者意义)的张量。
tf.add_to_collection():把张量发到一起,并用同一个命名空间命名多个张量,将多个张量组合成一个list,没有返回值。
tf.get_collection(name) :将之前通过tf.add_to_collection()语句添加的张量集合,通过name参数提取出来,返回一个list。
代码解析
import tensorflow as tf
value1=tf.get_variable(name = 'value1',shape=[3],initializer=tf.ones_initializer())
value2=tf.get_variable(name = 'value2',shape=[3],initializer=tf.random_uniform_initializer(maxval=-1,minval=1,seed=0))
loss1 = tf.get_variable(name = 'loss1',shape = [1],initializer=tf.constant_initializer(0))
loss2 = tf.get_variable(name = 'loss2',shape = [1],initializer=tf.constant_initializer(0))
tf.add_to_collection('value',value1)
tf.add_to_collection('value',value2)
tf.add_to_collection('loss',loss1)
tf.add_to_collection('loss',loss2)
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
value = tf.get_collection(value)
print value
print value[0].eval
print value[1].eval
loss = tf.get_collection(loss)
total_collection_num = tf.add_n(loss)
print loss
print total_collection_num