【tensorflow】:tf.add_to_collection() & tf.get_collection()

tf.add_to_collection() 和tf.get_collection()语句的组合使用旨在更好的管理同一类型(或者意义)的张量。

tf.add_to_collection():把张量发到一起,并用同一个命名空间命名多个张量,将多个张量组合成一个list,没有返回值。

tf.get_collection(name) :将之前通过tf.add_to_collection()语句添加的张量集合,通过name参数提取出来,返回一个list。

代码解析

#_*_coding:utf-8_*_
import tensorflow as tf

value1=tf.get_variable(name = 'value1',shape=[3],initializer=tf.ones_initializer())
value2=tf.get_variable(name = 'value2',shape=[3],initializer=tf.random_uniform_initializer(maxval=-1,minval=1,seed=0))

#第二个类别的tensor
loss1 = tf.get_variable(name = 'loss1',shape = [1],initializer=tf.constant_initializer(0))
loss2 = tf.get_variable(name = 'loss2',shape = [1],initializer=tf.constant_initializer(0))

#利用tf.add_to_collection()管理上述两个类别的tensor(张量)
tf.add_to_collection('value',value1)
tf.add_to_collection('value',value2)

tf.add_to_collection('loss',loss1)
tf.add_to_collection('loss',loss2)


with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    #利用tf.get_collection(name)来调取上述存入的两个tensor
    value = tf.get_collection(value)
    #可以利用eval函数来输出两个tensor的值
    print value
    print value[0].eval
    print value[1].eval


    #利用loss这个name来管理这两个变量
    #利用tf.add_n这个函数两统计collection中的tensor数量
    loss = tf.get_collection(loss)#返回的就是个列表
    total_collection_num = tf.add_n(loss)
    print loss
    print total_collection_num

猜你喜欢

转载自blog.csdn.net/qiu931110/article/details/80136595