tf.GraphKey简介

文章目录


官方文档链接:
https://www.tensorflow.org/api_docs/python/tf/GraphKeys#GLOBAL_VARIABLES

公有以下标准键:
GLOBAL_VARIABLES
LOCAL_VARIABLES
MODEL_VARIABLES
TRAINABLE_VARIABLEStf.Optimizer子类默认优化该类下的变量
SUMMARIES
QUEUE_RUNNERS
MOVING_AVERAGE_VARIABLES
REGULARIZATION_LOSSES
定义了以下标准键,但它们的集合不会像其他许多键那样自动填充:
WEIGHTS
BIASES
ACTIVATIONS

目前能够确定的键的关系如下
在这里插入图片描述

tf.get_collection()可以以list形式获取某个集合。

import tensorflow as tf
# 创建变量的3种方式

# 方式1
# tf.Variable(Tensor)
a = tf.Variable(tf.random_uniform(shape=[2,2], minval=0.0, maxval=1.0, dtype=tf.float32), name="a")

# 方式2
# tf.get_varizble(name=, shape=, initializer=)
b = tf.get_variable("b",
                    shape=[2,2],
                    initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0, dtype=tf.float32))

# 方式3
# 与方式2类似,只是将initializer写到了variable_scope中
with tf.variable_scope("variable___scope", initializer=tf.truncated_normal_initializer(mean=10.0, stddev=1.0, dtype=tf.float32)):
    c = tf.get_variable("c", shape=[2,2])

with tf.Session() as sess:
    # writer = tf.summary.FileWriter("logs_test", sess.graph)
    sess.run(tf.global_variables_initializer())
    # aa=tf.global_variables() # 与下一句等价
    aa=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    # bb=tf.trainable_variables()  # 与下一句等价
    bb=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    print(type(aa))
    print(aa)
    print(type(bb))
    print(bb)
    print("a:\n", a.eval())
    print("b:\n", b.eval())
    print("c:\n", c.eval())
<class 'list'>
[<tf.Variable 'a:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'b:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'variable___scope/c:0' shape=(2, 2) dtype=float32_ref>]
<class 'list'>
[<tf.Variable 'a:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'b:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'variable___scope/c:0' shape=(2, 2) dtype=float32_ref>]
a:
 [[0.794688   0.50156784]
 [0.7120378  0.08468437]]
b:
 [[ 0.8215166   0.8389353 ]
 [ 0.06688708 -0.4017645 ]]
c:
 [[11.046042   9.249809 ]
 [ 9.4180565 10.788679 ]]

总结

tensorflow的Graph中有很多collection,标准的有大概10个,分别是

公有以下标准键:
GLOBAL_VARIABLES
LOCAL_VARIABLES
MODEL_VARIABLES
TRAINABLE_VARIABLEStf.Optimizer子类默认优化该类下的变量
SUMMARIES
QUEUE_RUNNERS
MOVING_AVERAGE_VARIABLES
REGULARIZATION_LOSSES
定义了以下标准键,但它们的集合不会像其他许多键那样自动填充:
WEIGHTS
BIASES
ACTIVATIONS

在定义变量的时候,这些变量会被自动分配到某些集合中。想要获取集合,可以用tf.get_collection函数
如获取Graph中的GLOBAL_VARIABLES集合,可用
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)或者tf.global_variables()(二者等价)

猜你喜欢

转载自blog.csdn.net/lllxxq141592654/article/details/85329740
tf