tensorflow【2】-Variable 详解

Variable 的主要作用是维护特定节点的状态,如深度学习模型参数

创建

tf.Variable 是最常用的创建变量的方法

class VariableV1(Variable):
    def __init__(self,  
           initial_value=None,        # -- 变量值
           trainable=None,            # 该变量是否需要训练,或者说是否能被优化器更新
           collections=None,
           validate_shape=True,
           caching_device=None,
           name=None,                 # -- 变量的名字
           variable_def=None,
           dtype=None,                # -- 变量类型
           expected_shape=None,
           import_scope=None,
           constraint=None,
           use_resource=None,
           synchronization=VariableSynchronization.AUTO,
           aggregation=VariableAggregation.NONE,
           shape=None):               # -- 变量尺寸
        pass

tf.Variable 是操作 (op),返回值是 Variable;

d1 = tf.Variable(2)
d2 = tf.Variable(3, dtype=tf.int32, name='int')
d3 = tf.Variable(4., dtype=tf.float32, name='float')
d4 = tf.add(d1, d2)
d5 = d1 + d2
# d6 = tf.add(d1, d3)     ### 不同类型的数据不能运算

init = tf.global_variables_initializer()        ### 变量必须初始化

sess1 = tf.Session()
sess1.run(init)
print(sess1.run(d4))        # 5
print(sess1.run(d5))        # 5
# print(sess1.run(d6))
print(type(d5))             # <class 'tensorflow.python.framework.ops.Tensor'>

另一种创建变量的方法 tf.get_variable

d1 = tf.get_variable('d1', shape=[2, 3], initializer=tf.ones_initializer)
d2 = tf.get_variable('d2', shape=[3, 2], initializer=tf.zeros_initializer)
sess3 = tf.Session()
sess3.run(tf.global_variables_initializer())
print(sess3.run(d1))
# [[1. 1. 1.]
#  [1. 1. 1.]]
print(sess3.run(d2))

初始化

Variable 在参与计算之前必须初始化, 两种方式

d1 = tf.Variable(1)
print(d1)       # <tf.Variable 'Variable:0' shape=() dtype=int32_ref>
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(d1))     # 1

###
d2 = tf.Variable(1)
with tf.Session() as sess:
    tf.global_variables_initializer().run()

如果初始化后又创造了新变量,需要重新初始化

初始化之前只是个节点

属性方法

Variable 所有属性方法如下

['SaveSliceInfo', 'aggregation', 'assign', 'assign_add', 'assign_sub', 'batch_scatter_update', 'constraint', 'count_up_to', 'device', 'dtype', 'eval', 'from_proto', 'gather_nd', 'get_shape', 
'graph', 'initial_value', 'initialized_value', 'initializer', 'load', 'name', 'op', 'read_value', 'scatter_add', 'scatter_nd_add', 'scatter_nd_sub', 'scatter_nd_update', 'scatter_sub', 
'scatter_update', 'set_shape', 'shape', 'sparse_read', 'synchronization', 'to_proto', 'trainable', 'value']

变量名

节点名,也是变量名,如果创建 Variable 时显式的设置了 name,则取该 name,如果没有,则以 Variable_1 格式递增下标

d1 = tf.Variable(tf.zeros(2,2))
d2 = tf.Variable(2., dtype=tf.float32, name='d2')
d3 = tf.Variable(3)
print(d1)               # <tf.Variable 'Variable:0' shape=() dtype=float32_ref>
print(d1.op.name)       # Variable_1
print(d2.op.name)       # d2
print(d3.op.name)       # Variable_2

内存机制

tf.Variable 创建的变量与张量一样,可以作为操作的输入和输出,不同之处在于:

1. 张量的生命周期通常依赖计算的完成而结束,内存随即释放

2. 变量常驻内存,随计算同步更新,不随计算结束而结束

d1 = tf.Variable(2.)
d2 = tf.constant(42.)
print(d2)       # Tensor("Const:0", shape=(), dtype=float32)
d3 = tf.assign_add(d1, 1.)
d4 = tf.add(d2, 1.)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(2):
        print(sess.run(d3))     # 3.0,  4.0     ### 循环中 variable 状态持续更新,内存不释放
        print(sess.run(d4))     # 43.0, 43.0    ### 循环中 constant 状态不更新,内存实时释放
    print(d1, sess.run(d1))     # <tf.Variable 'Variable:0' shape=() dtype=float32_ref> 4.0     ### 循环结束之后 variable 状态保持,内存未释放
    print(d2, sess.run(d2))     # Tensor("Const:0", shape=(), dtype=float32) 42.0               ### 循环结束之后 constant 状态恢复,内存释放

Variable 的这个特性常用于模型参数的迭代

Variable 赋值

变量的赋值不能直接用 =,有几种方式:tf.assign、var.assign、tf.assign_add

d1 = tf.Variable(2)
d2 = tf.Variable(3, dtype=tf.int32, name='int')
d3 = tf.Variable(4., dtype=tf.float32, name='float')
## method1
# d4 = tf.assign(d2, d3)      ### 两个变量数据类型要一致
d5 = tf.assign(d2, d1)
# d7 = tf.assign(d6, d3)      ### 被赋值的变量必须事先存在
## method2
d8 = d2.assign(100)
## method3:加个数并赋值
d9 = tf.assign_add(d2, 50)

with tf.Session() as sess2:
    tf.global_variables_initializer().run()
    print(sess2.run(d5))        # 2     d5 被赋值了,等于 d2
    print(sess2.run(d2))        # 2     真正的 d2 也变了
    print(sess2.run(d8))        # 100
    print(sess2.run(d9))        # 150
    print(sess2.run(tf.assign_add(d2, 3)))  # 153
    print(sess2.run(tf.assign(d2, 3)))  # 3
    print(sess2.run(tf.assign(d2, d1))) # 2

trainable

trainable 属性指定变量是否参与训练,或者说是否能被优化器更新,类似于 PyTorch 中的 requires_grad;

False 代表不参与训练,默认为 True;

trainable 为只读属性,只在创建 Variable 时生效,后期无法更改;

在创建优化器 Optimizer 的 minimize 张量时,tf 会把所有可训练的 Variable 收集到 trainable_variables 中,此后增加或者删除 可训练的变量,trainable_variables 不会变化;

x = tf.Variable(3.0, dtype=tf.float32, trainable=False)     ### x 的 trainable 为 F,不参与训练
y = tf.Variable(13.0, dtype=tf.float32)         ### 参与训练
train_op = tf.train.AdamOptimizer(0.01).minimize(tf.abs(y - x))
with tf.Session()as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(5):
        _, xx, yy = sess.run([train_op, x, y])
        print('epoch', i, xx, yy)  # 观察 x 和 y 的变化

# epoch 0 3.0 12.99
# epoch 1 3.0 12.98
# epoch 2 3.0 12.969999
# epoch 3 3.0 12.959999
# epoch 4 3.0 12.949999

tf.trainable_variables and tf.all_variables

tf.trainable_variables()   返回的是 所有需要训练的变量列表

tf.all_variables()      返回的是 所有变量的列表

v1 = tf.Variable(0, name='v1')
v2 = tf.Variable(tf.constant(5, shape=[1], dtype=tf.float32), name='v2')
global_step = tf.Variable(6, name='global_step', trainable=False)       # 声明不是训练变量

for ele1 in tf.trainable_variables():
    print(ele1.name)
# v1:0
# v2:0

for ele2 in tf.all_variables():
    print(ele2.name)
# v1:0
# v2:0
# global_step:0

Variable 的保存和加载

保存和加载都需要创建 Saver 对象,然后调用 save 保存 和 restore 加载

from tensorflow.core.protobuf import saver_pb2
class Saver(object):
    def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None):
        pass
    def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False,
           save_debug_info=False):
        pass
    def restore(self, sess, save_path):
        """Restores previously saved variables."""
        pass

 

创建 Saver 对象

var_list:被保存或者加载的 variable,该参数的取值有几种形式:

  1. 默认为 None,即针对所有 Variable
  2. list 格式,指定 variable,variable 的命名默认为 v1、v2...
  3. dict 格式,指定 name 和 variable

save 方法

save_path:指定存储路径,一般以 ckpt(checkpoint) 结尾

global_step:指定全局阶段,实际上就是个标记,通常是用一个数字指定 variable 在哪个阶段保存的,这个数字位于 filename 之后,具体见例子

restore 方法

save_path:注意这个路径要看 save 时的 global_step

d1 = tf.Variable(1.)
d2 = tf.Variable(2., dtype=tf.float32, name='d2')
init = tf.global_variables_initializer()

### 初始化 Saver 对象
saver = tf.train.Saver()            ### 保存所有变量
saver1 = tf.train.Saver([d1, d2])       ### list 指定变量,变量名默认为 v1 v2 递增
saver2 = tf.train.Saver({'v1': d1, 'v2':d2})      ### dict 指定变量和变量名

with tf.Session() as sess:
    sess.run(init)
    ### save 方法保存变量
    saver.save(sess, './var/all.ckpt')
    saver1.save(sess, './var/list.ckpt', global_step=0)
    print(saver2.save(sess, './var/dict.ckpt', global_step=1))      # ./var/dict.ckpt-1
    sess.run(tf.assign_add(d2, 3.))     ### 保存之后再改变
    print(sess.run(d2))     # 5.0

    ### 加载变量 1:同一个 saver、sess
    saver2.restore(sess, './var/dict.ckpt-1')
    print(sess.run(d2))     # 2.0       ### 加载的是改变前的值,说明保存成功

### 加载变量 2:同一个 saver,不同的 sess
with tf.Session() as sess:
    saver2.restore(sess, './var/dict.ckpt-1')
    print(sess.run(d2))     # 2.0

### 加载变量 3:不同的 saver,不同的 sess
saver2 = tf.train.Saver({'v2':d2})
with tf.Session() as sess:
    saver2.restore(sess, './var/dict.ckpt-1')
    print(sess.run(d2))     # 2.0

可见,保存与加载相互独立

上述代码保存 variable 结果如下

1. global_step 被加到 filename 之后

2. save 会生成 4 个文件 data、index、meta、checkpoint

  • data:存放模型参数
  • meta:存放计算图
  • checkpoint:记录模型存储的路径,model_checkpoint_path 代表最新的模型存储路径,all_model_checkpoint_paths 代表所有模型的存储路径

3. 最多只保存近 5 次的存储

4. 多次保存只有一个 checkpoint

参考资料:

https://www.cnblogs.com/weiyinfu/p/9973022.html  tensorflow动态设置trainable

猜你喜欢

转载自www.cnblogs.com/yanshw/p/12341295.html