TensorFlow训练模型的保存与加载

在构建TensorFlow模型训练过程中有可能会出错,或者训练时间比较长的情况下,我们希望把训练的模型参数数据,结构保存起来,所以模型保存是很有必要的,等我们测试模型的时候可以直接加载保存的模型,测试我们的数据。

TensorFlow的模型计算是以图为单位,图中含有多个节点OP,每个节点OP包含计算变量等,使用tf.train.Saver()保存模型中的所有变量以及图结构。Saver构造器使用如下:

saver = tf.train.Saver(max_to_keep=10)

max_to_keep参数表示保存的模型个数,默认是5个,下面是模型保存的通用代码:

saver = tf.train.Saver(max_to_keep=10)
sess = tf.Session()
for i in range(2000):
    #训练过程
    #train()
    saver.save(sess, CHECK_POINT_DIR+'/model', global_step = i)

其中CHECK_POINT_DIR是模型保存的目录,model是保存的文件名,文件名+global_step,指示哪一次保存的模型

模型的加载过程代码如下:

saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
if ckpt and ckpt.model_checkpoint_path:
	saver.restore(sess, ckpt.model_checkpoint_path)

其中CKPT_DIR是模型保存的目录,ckpt.model_checkpoint_path是保存的checkpoint文件的路径。

接下来看看模型保存路径下的文件:



meta文件是模型定义的内容;ckpt(或data和index)文件是保存的模型数据

猜你喜欢

转载自blog.csdn.net/jiangyingfeng/article/details/81062379