Tensorflow_03_Save and Restore 储存和载入

Brief 概述

在理解了建构神经网络的大致函数用途,且熟悉了神经网络原理后,我们已经大致具备可以编写神经网络的能力了,在涉及比较复杂的神经网络结构前,还有一件重要的事情需要了解,那就是中途存档和事后读取的函数,它攸关到庞大的算力和时间投入后产出的结果是否能够被再次使用,是一个绝对必须弄清楚的环节,因此本节主要围绕一个主题:

  • Checkpoint 检查点

它如同会议记录一般,可以针对性的把训练过程记录下来,除了避免前功尽弃之外,还可以让我们有机会一窥训练过程的究竟,从演变过程中寻找改善算法的方案。

p.s. 关于设备如果手边没有,非常建议直接使用云端的计算服务,如 AWS, FloydHub 等平台

其他在深度学习中常用的函数定义方法可以参考上一篇文章: Tensorflow_02_Useful Functions 常用函数大全

Checkpoint 检查点

在初期一般训练模型简单且训练速度极快,对于参数中间变化的过程我们也不会特别在意,但是到了复杂的神经网络训练过程时,为参数训练过程中途存档这件事情就会变得非常重要,这就像我们玩电玩游戏闯关的时候,希望最好能够中途存档,如果死在半路上可以直接从存档的地方恢复游戏。

Save checkpoints 储存检查点

同理深度学习训练过程,一般训练耗费时间约为几天乃至一周,如果中途发生机器停机或是任何意外导致训练终止,我们可以从检查点记录的地方重新开始。抑或者如果我们要分析训练过程中参数的变化走势,检查点也非常实用。使用的类为:

  • tf.train.Saver(max_to_keep=None) 档名: 「.ckpt」
  • .Saver({’save_w‘: weight}) 括弧中可以用字典的方式指定只要储存哪一个参数
  • max_to_keep=None: 最多有几个检查点被保存下来,如果是 None 或是 0 则表示全保存
  • keep_checkpoint_every_n_hours=1: 设置几个小时保存一次检查点

变量以二进制的方式被存在名为 .ckpt 的档案中,内容包含了变量的名字和对应张量的数值,创建一个该类的示例,就可以呼叫里面储存与载入储存文件内容的函数方法:

  • tf.train.Saver().save(sess, './file_directory', global_step=int(num))
  • sess: 表示要储存哪个绘话里面的参数
  • './file_directory/file_name': 储存的路径沿着执行训练的 .py 文档路径位置继续指定路径,如果文件夹不存在指定目录的话,它会自行创建。官网教程中建议档名后面连同后缀一起加上,如下代码...
  • global_step:指定一个数字,将一起被纳入检查点文件命名中

!!! 储存这些参数的时候特别需要注意申明清楚参数的数据类型非常重要,它攸关到之后要呼叫回这些参数的时候是否顺利,如果没有事先申明清楚,大概率上会有错误发生。

下面代码展示如何保存检查点:

import numpy as np
import tensorflow as tf

x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

weight = tf.Variable(tf.random_uniform(shape=[1], minval=-1.0, maxval=1.0), 
                     dtype=np.float32)#, name='weight')
bias = tf.Variable(tf.zeros(shape=[1]), dtype=np.float32, name='bias')
y = weight * x_data + bias

loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
training = optimizer.minimize(loss)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

# The instance is created to call the method saving checkpoint
saver = tf.train.Saver()
save_w = tf.train.Saver({'a_name': weight})

for step in range(101):
    sess.run(training)
    if step % 10 == 0:
        print('Round {}, weight: {}, bias: {}'
              .format(step, sess.run(weight[0]), sess.run(bias[0])))
        saver.save(sess, './checkpoint/linear.ckpt', global_step=step)
        save_w.save(sess, './weight/linear.ckpt', global_step=step)
        
saver.save(sess, './checkpoint/linear.ckpt')
sess.close()


### ----- Result is shown below ----- ###
Round 0, weight: 0.6087742447853088, bias: 0.031045857816934586
Round 10, weight: 0.3177388906478882, bias: 0.18408644199371338
Round 20, weight: 0.19332920014858246, bias: 0.2503160834312439
Round 30, weight: 0.14000359177589417, bias: 0.27870404720306396
Round 40, weight: 0.11714668571949005, bias: 0.2908719480037689
Round 50, weight: 0.10734956711530685, bias: 0.29608744382858276
Round 60, weight: 0.10315024852752686, bias: 0.29832297563552856
Round 70, weight: 0.10135028511285782, bias: 0.29928117990493774
Round 80, weight: 0.10057878494262695, bias: 0.29969191551208496
Round 90, weight: 0.10024808347225189, bias: 0.2998679280281067
Round 100, weight: 0.10010634362697601, bias: 0.2999434173107147

检查点的路径设置需要使用 「./.../.../...」 的格式去写路径,尤其是开头的 ./ 必须加上,否则在某些平台上会出现错误,等代码运行完毕后在下面 .py 文档执行路径下出现我们设置的储存文件夹和文件名称,如下图:

扫描二维码关注公众号,回复: 3410558 查看本文章

在默认情况下 tf.train.Saver(max_to_keep=5) 是我们无特别设定的结果,因此只会保存离最近更新的五个参数,其他的参数将即自动删除。

Read checkpoints 读取检查点

文件存好之后接下来就是读取上图中储存的文件,储存在文件里面的数据是一个原封不动的 tf.Variable() 物件,有着与储存前一模一样的名字和属性,甚至在呼叫回该储存的变量时也不用初始化,是一个非常全面的保存结果, 只是需要记得: 「同样变量名的物件需要事先存在在代码中, 并且数据类型和长相必须一模一样。

读取的方式也很直观,同样的创建一个 tf.train.Saver() 示例,并用该示例里面的方法 .restore() 完成读取,读取完毕后储存的参数就回像起死回生一般重新回到我们的代码中。

  • tf.train.Saver().restore(sess, 'file_directory')
  • sess: 表示我们希望把该储存的内容重新叫回哪一个绘话中
  • './file_directory/file_name': 表示我们要呼叫的该存档文件

p.s. 如果在储存过程中有加上 global_step 参数,呼叫文档名的时候就必须一起把数字也加上去,如下代码。

呼叫储存文件的时候有以下三种情况:

  1. 最直接: 使用 tf.train.Saver() 创建示例后,呼叫 .restore() 方法配合对应名字,成功回到训练中途的记录
  2. 第一个方法受阻: 绕道使用 .meta 储存文件,并使用 tf.import_meta_graph() 示例的 .restore() 方法,同样可以成功回到训练中途的记录
  3. 呼叫只储存部分参数的记录档: 创建一个示例前先在 tf.train.Saver() 括弧中使用字典形式声明好当时部分储存的时候对应一模一样名字的字典键和参数名,再用 .restore() 方法成功回到训练中途的记录

详细代码如下演示:

import tensorflow as tf

# tf.reset_default_graph()
weight = tf.Variable([33], dtype=tf.float32)#, name='weight')
bias = tf.Variable([3], dtype=tf.float32, name='bias')

saver = tf.train.Saver()
# saver = tf.train.import_meta_graph('./checkpoint/linear.ckpt.meta')
saver_2 = tf.train.Saver({'a_name': weight})
init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)
path1 = saver.restore(sess, './checkpoint/linear.ckpt-90')
path2 = saver_2.restore(sess, './weight/linear.ckpt-60')
print(sess.run(weight))
print(sess.run(bias))
sess.close()

''' 
print(sess.run(biases))

### ----- Result as follow ----- ###
FailedPreconditionError: 
Attempting to use uninitialized value Variable
[[Node: _retval_Variable_0_0 = _Retval[T=DT_FLOAT, index=0, 
  _device="/job:localhost/replica:0/task:0/device:CPU:0"](Variable)]]
'''


### ----- Result is shown below ----- ###
INFO:tensorflow:Restoring parameters from ./checkpoint/linear.ckpt-90
INFO:tensorflow:Restoring parameters from ./weight/linear.ckpt-60
[0.10315025]
[0.29986793]

可以观察到,如果没有成功导入内容, sess.run() 执行一个参数的时候就会被通知该参数没有初始化,需要特别注意。另外如果重复导入同样的值到该代码中,那么该值以最后一次导入为主,如上面代码中的 weight,最近导入的 60 个回合训练的 weight 值比训练 90 个回合的 bias 值还要不准得多。

  • tf.train.latest_checkpoint('./.../...')
  • more to update

!! 重要 !!  导入没有成功, 报错 >> ValueError: At least two variables have the same name: Variable

花了一整个晚上找方法的错误,原因还是在于 tf.Variable() 的格式没有完全一样,前面只专注在数据格式上面,但是其节点名称必须也完全一致才可以! 如果表明名称 name='a_name', 那么就都不要写,如果表明了名称,那就必须完全一致才行!

猜你喜欢

转载自blog.csdn.net/Kuo_Jun_Lin/article/details/81772952