Tensorflow学习笔记——tf.Variable()和tf.get_variable()

tf.Variable()

tf.Variable是一个Variable类
通过variable维持图graph的状态,以便在sess.run()中执行,可以用Variable类创建一个实例在图中增加变量

tf.Variable(
initial_value=None, 
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None
)
  • initial_value
    Tensor或可转换为Tensor的Python对象,它是Variable的初始值。除非validate_shape设置为False,否则初始值必须具有指定的形状。也可以是一个可调用,没有参数,在调用时返回初始值。在这种情况下,必须指定dtype。
  • name
    变量的可选名称。默认为“Variable”并自动获取。
  • dtype
    如果设置,则initial_value将转换为给定类型。如果为None,则保留数据类型(如果initial_value是Tensor),或者convert_to_tensor将决定。

一般常用的参数包括初始化值和名称(是该变量的唯一索引)
在使用变量之前必须要进行初始化

tf.get_variable

获取一个已经存在的变量或者创建一个新的变量

get_variable(
	name,
	shape=None,
	dtype=None,
	initializer=None,
	regularizer=None,
	trainable=True,
	collections=None,
	caching_device=None,
	partitioner=None,
	validate_shape=True,
	use_resource=None,
	custom_getter=None,
	constraint=None
)
  • name
    新变量或现有变量的名称
  • shape
    新变量或现有变量的形状
  • dtype
    新变量或现有变量的类型
  • initializer
    如果创建了则用它初始化变量

区别

  1. 使用tf.Variable时,如果检测到命名冲突,系统会自己处理。使用tf.get_variable()时,系统不会处理冲突,而会报错。
  2. 基于这两个函数的特性,当我们需要共享变量的时候,需要使用tf.get_variable()。在其它两种情况下,这两个的用法是一样的。

Reference:
https://blog.csdn.net/MrR1ght/article/details/81228087
https://www.jianshu.com/p/2061b221cd8f?utm_campaign=maleskine&utm_content=note&utm_medium=seo_notes&utm_source=recommendation

猜你喜欢

转载自blog.csdn.net/weixin_42018112/article/details/88623100
今日推荐