代码:
import tensorflow as tf
# Sample 2x3 data used by both demos in this script.
# NOTE: the name `input` shadows Python's builtin input(); it is kept as-is
# because the get_variable demo further down reads it.
input = [[1, 2, 3], [4, 5, 6]]
graph = tf.compat.v1.Graph()

# Demo 1: create a variable with the tf.Variable constructor.
with graph.as_default():
    input_tf = tf.Variable(input, dtype=tf.float32, name="input")
    print("input_tf shape: ", input_tf.get_shape().as_list())
    print("input_tf dtype: ", input_tf.dtype)
    # Build the initializer op inside this graph so it is guaranteed to
    # belong to the same graph as input_tf.
    init_op = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    # A graph-mode variable has no value until its initializer has run.
    sess.run(init_op)
    print("input_tf value :\n", sess.run(input_tf))
print()
# Demo 2: create a variable with tf.compat.v1.get_variable. This is a
# function, not an object. It takes the variable's starting value from an
# initializer rather than from a concrete tensor.
graph1 = tf.compat.v1.Graph()
# Initializer that fills the variable with the values of `input`
# (the 2x3 nested list defined earlier in this script).
initializer = tf.compat.v1.constant_initializer(input)
with graph1.as_default():
    # dtype is spelled out; float32 is also get_variable's default.
    input1_tf = tf.compat.v1.get_variable(
        name="input", shape=[2, 3], dtype=tf.float32, initializer=initializer
    )
    print("input1_tf shape: ", input1_tf.get_shape().as_list())
    print("input1_tf dtype: ", input1_tf.dtype)
    # Build the initializer op inside graph1 so it covers input1_tf.
    init_op = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph1) as sess:
    # Run the initializer before reading the variable's value.
    sess.run(init_op)
    print("input1_tf value :\n", sess.run(input1_tf))
输出:
input_tf shape:  [2, 3]
input_tf dtype:  <dtype: 'float32'>
input_tf value :
 [[1. 2. 3.]
 [4. 5. 6.]]

input1_tf shape:  [2, 3]
input1_tf dtype:  <dtype: 'float32'>
input1_tf value :
 [[1. 2. 3.]
 [4. 5. 6.]]