TensorFlow 中创建变量的两种方法

代码: 

import tensorflow as tf
# Sample 2x3 matrix used to initialise both variables. Named `input_data`
# rather than `input` so the built-in input() is not shadowed.
input_data = [[1, 2, 3], [4, 5, 6]]

# Method 1: create the variable with the tf.Variable constructor, passing
# the initial value directly.
graph = tf.compat.v1.Graph()
with graph.as_default():
    input_tf = tf.Variable(input_data, dtype=tf.float32, name="input")
    print("input_tf shape: ", input_tf.get_shape().as_list())
    print("input_tf dtype: ", input_tf.dtype)
    # Build the initializer op while `graph` is the default graph so it is
    # explicitly attached to this graph, instead of relying on the Session
    # context manager below to make `graph` the default.
    init_op = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    # In graph mode a variable has no value until its initializer is run.
    sess.run(init_op)
    print("input_tf value :\n", sess.run(input_tf))
print()

# Method 2: create the variable with tf.compat.v1.get_variable, which takes
# an explicit shape plus an initializer object instead of an initial value.
graph1 = tf.compat.v1.Graph()
# The initializer is a plain Python object; it only produces graph ops when
# get_variable uses it inside graph1.
initializer = tf.compat.v1.constant_initializer(input_data)
with graph1.as_default():
    # float32 is get_variable's default dtype; passing it explicitly keeps
    # this variable visibly consistent with method 1.
    input1_tf = tf.compat.v1.get_variable(
        name="input", shape=[2, 3], dtype=tf.float32, initializer=initializer
    )
    print("input1_tf shape: ", input1_tf.get_shape().as_list())
    print("input1_tf dtype: ", input1_tf.dtype)
    # As above, create the init op inside graph1's context.
    init_op1 = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph1) as sess:
    sess.run(init_op1)
    print("input1_tf value :\n", sess.run(input1_tf))

输出:

input_tf shape:  [2, 3]
input_tf dtype:  <dtype: 'float32'>
input_tf value :
 [[1. 2. 3.]
 [4. 5. 6.]]

input1_tf shape:  [2, 3]
input1_tf dtype:  <dtype: 'float32'>
input1_tf value :
 [[1. 2. 3.]
 [4. 5. 6.]]
发布了105 篇原创文章 · 获赞 17 · 访问量 11万+

猜你喜欢

转载自blog.csdn.net/qq_38890412/article/details/104065589