tensorflow.get_variable

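Used together with collections, tf.get_variable works somewhat like a C++ struct. tf.add_to_collection groups named variables under a string key, and tf.get_collection returns everything stored under that key as one list.
Code: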
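# Written for TensorFlow 1.x graph mode (in TF 2.x these functions live under tf.compat.v1)
import tensorflow as tf

# Create three named float32 variables with the given shapes
a = tf.get_variable("a", shape=[2])
b = tf.get_variable("b", shape=[2, 2])
c = tf.get_variable("c", shape=[2])

# Group them under string keys: a and b go into "group1", c into "group2"
tf.add_to_collection("group1", a)
tf.add_to_collection("group1", b)
tf.add_to_collection("group2", c)

# Retrieve each group as a list, in the order the variables were added
print(tf.get_collection("group1"))
print(tf.get_collection("group2"))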
Output:
[<tf.Variable 'a:0' shape=(2,) dtype=float32_ref>, <tf.Variable 'b:0' shape=(2, 2) dtype=float32_ref>]
[<tf.Variable 'c:0' shape=(2,) dtype=float32_ref>]
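The grouping behaviour shown in the output can be reproduced without TensorFlow. Below is a minimal pure-Python model of named collections. It is a hypothetical sketch, not TensorFlow's implementation. FakeVariable and CollectionRegistry are illustrative names only. The optional scope filter is an assumption modelled loosely on the scope argument of tf.get_collection.

```python
import re
from dataclasses import dataclass


@dataclass
class FakeVariable:
    """Hypothetical stand-in for a tf.Variable: only a name and a shape."""
    name: str
    shape: tuple


class CollectionRegistry:
    """Hypothetical model of a graph's named collections (not TensorFlow code)."""

    def __init__(self):
        # collection key -> list of values, kept in insertion order
        self._collections = {}

    def add_to_collection(self, key, value):
        # Create the list on first use, then append (no de-duplication here)
        self._collections.setdefault(key, []).append(value)

    def get_collection(self, key, scope=None):
        # Return a copy so callers cannot mutate the stored list
        items = list(self._collections.get(key, []))
        if scope is None:
            return items
        # Assumed filter: keep items whose name matches the regex at its start
        pattern = re.compile(scope)
        return [v for v in items if pattern.match(v.name)]


registry = CollectionRegistry()
a = FakeVariable("a:0", (2,))
b = FakeVariable("b:0", (2, 2))
c = FakeVariable("c:0", (2,))

registry.add_to_collection("group1", a)
registry.add_to_collection("group1", b)
registry.add_to_collection("group2", c)

print([v.name for v in registry.get_collection("group1")])              # ['a:0', 'b:0']
print([v.name for v in registry.get_collection("group2")])              # ['c:0']
print([v.name for v in registry.get_collection("group1", scope="b")])   # ['b:0']
print(registry.get_collection("missing"))                               # []
```

Like a struct, a collection is a named container for related members. Unlike a struct, its members are added at runtime and are looked up by a string key.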


Reposted from blog.csdn.net/qq_33345917/article/details/84993293