如何保存训练模型

我们花了大量的时间去训练这个模型,当我们重新开机时,由刚刚训练得到的比较好的模型参数就全部丢失了,没有办法直接拿来用,只能利用以前记录下来的比较好的那组超参数接着进行模型训练,训练完成后再应用。这样是比较麻烦的,不符合日常的应用场景。

我们需要把训练好的模型做持久化保存,哪怕关机重启也不会丢失,可以把模型重新读取出来以供使用。这么做还有一个好处:当我们在处理一个比较复杂的模型时,需要花费大量时间,有的大型模型可能需要几天甚至几十天,如果训练中发生断电或是需要关机,模型不能保存下来,是比较麻烦的。

这里会提到一个“断点续训”的概念,即不管训练到什么阶段,可以暂停训练,下一次需要的时候,可以从暂停点继续训练。这可以通过模型的保存和还原来实现。

初始化参数和文件目录

在这里插入图片描述
为了做好模型的保存,需要添加一些参数。首先用到了save_step,即模型保存的粒度(训练多少轮保存一次),如果设为1,那么每一轮都会保存。在这里设置为5,即每五轮保存一次。
模型保存在计算机上需要有一个具体的位置,我们建立了一个目录“./ckpt_dir/”,前面的点表示当前文件目录下新建一个子目录。这里会用到os库,所以需要导入进来,接着做一个判断,如果当前目录下不存在这个子目录则创建一个新的子目录。

训练并存储模型

在这里插入图片描述
在训练模型之前,若所有的变量都定义好了,就可以调用tf.train.Saver()去初始化saver,存储模型将通过saver来实现。
在这里插入图片描述
在训练模型中,要加入的代码也不多。橙框的上面部分输出了这一轮的训练结果(损失值以及精确率);橙框中的代码是加入的代码,它表示当这一轮结果需要保存的时候,调用saver的save方法,里面的有两个参数:第一个参数是当前运行的会话,它要把会话中的所有变量的值保存下来,第二个参数是所要保存模型的文件名,这个文件名需要包含具体目录,所以这里利用了os.path.join()函数来合成,第一个参数是刚才定义的子目录,第二个参数是模型的名字,名字中保留了一个记录轮次的格式,这样就可以知道这个文件是第几轮训练的结果。然后输出一个提示,告诉我们已经保存完成。当所有轮次训练完毕后,再通过蓝框中的代码保存最后的结果。

再完整看一下修改后的训练过程(框中为新增部分):在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_43485035/article/details/109638202