Saving and loading a trained model in TensorFlow

In machine learning, and in deep learning (DL) especially, training a model can take hours or even days. If the testing code has a bug, re-running the whole training every time wastes a lot of time. Once the training part works, you can save the trained model, then load it directly next time and run the tests against it.

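In TensorFlow, the class that saves (save) and loads (restore) a model is tf.train.Saver (a TensorFlow 1.x API; in TensorFlow 2 it is available as tf.compat.v1.train.Saver). It stores variables as key-value pairs keyed by variable name. If no arguments are passed to the constructor, it saves all variables.
A model is saved as follows: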

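import os
import tensorflow as tf  # TensorFlow 1.x API

# Declare variables and ops (example variables; replace with your own model)
w = tf.Variable(tf.random_normal([2, 3]), name="w")
b = tf.Variable(tf.zeros([3]), name="b")
# Op that initializes all variables
init_op = tf.global_variables_initializer()

# Create the Saver after the variables exist; it adds ops to save and restore them
saver = tf.train.Saver()

# Saver.save fails if the parent directory does not exist
os.makedirs("save_path", exist_ok=True)

with tf.Session() as sess:
    sess.run(init_op)
    # ... model training goes here ...
    # Convenience method that runs the save op
    saver.save(sess, "save_path/file_name.ckpt")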
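The key-value storage described above can be illustrated without TensorFlow. The following stdlib-only toy sketch does not use TensorFlow's API or checkpoint format, which is binary. The hypothetical `ToySaver` class writes a JSON object that maps variable names to values. Like `tf.train.Saver()`, it saves every variable when no `var_list` is given.

```python
import json
import os
import tempfile


class ToySaver:
    """Hypothetical stand-in for tf.train.Saver; NOT TensorFlow's API or file format.

    A checkpoint is a JSON object mapping variable names (keys) to values.
    """

    def __init__(self, variables, var_list=None):
        # variables: every variable in the "model", as a name -> value dict
        self._variables = variables
        # var_list=None mirrors Saver's default: save every variable
        self._names = list(variables) if var_list is None else list(var_list)

    def save(self, path):
        # Snapshot the selected variables as name -> value pairs
        checkpoint = {name: self._variables[name] for name in self._names}
        with open(path, "w") as f:
            json.dump(checkpoint, f)
        return path


# Hypothetical model variables
variables = {"w": [[0.1, 0.2], [0.3, 0.4]], "b": [0.0, 0.0]}
ckpt_dir = tempfile.mkdtemp()

full_path = ToySaver(variables).save(os.path.join(ckpt_dir, "all.json"))
partial_path = ToySaver(variables, var_list=["w"]).save(os.path.join(ckpt_dir, "w_only.json"))

with open(full_path) as f:
    print(sorted(json.load(f)))  # ['b', 'w']
with open(partial_path) as f:
    print(sorted(json.load(f)))  # ['w']
```

You would pass `var_list` to checkpoint only part of a model. `tf.train.Saver` accepts a list of variables or a name-to-variable dict for this purpose.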

To load the model, declare the same variables (same names and shapes), create a Saver object, and call its restore method:

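import tensorflow as tf  # TensorFlow 1.x API

# Declare the same variables (same names and shapes) as when saving
w = tf.Variable(tf.random_normal([2, 3]), name="w")
b = tf.Variable(tf.zeros([3]), name="b")
init_op = tf.global_variables_initializer()

# Create the Saver object
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)  # Optional: the restored values override the initial values
    saver.restore(sess, "save_path/file_name.ckpt")
    # w and b now hold the values saved after training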
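The comment on `init_op` says running the initializer before restoring is harmless. This stdlib-only toy sketch shows why. It uses hypothetical `save_checkpoint` and `restore_checkpoint` helpers that store JSON files, not TensorFlow's format. Restore overwrites each declared variable with the value stored under its name, so whatever the initializer set is replaced. The sketch raises `KeyError` if a declared variable is missing from the checkpoint. This loosely mirrors how `saver.restore` reports an error in that case.

```python
import json
import os
import tempfile


def save_checkpoint(variables, path):
    """Write name -> value pairs to a JSON file (toy format, not TensorFlow's)."""
    with open(path, "w") as f:
        json.dump(variables, f)


def restore_checkpoint(variables, path):
    """Overwrite each declared variable with the value stored under its name."""
    with open(path) as f:
        checkpoint = json.load(f)
    # Check every name first so a failed restore leaves the variables untouched
    missing = [name for name in variables if name not in checkpoint]
    if missing:
        raise KeyError(f"not found in checkpoint: {missing}")
    for name in variables:
        variables[name] = checkpoint[name]


path = os.path.join(tempfile.mkdtemp(), "model.json")

# Run 1: "train", then save
trained = {"w": [0.5, -1.5], "b": [2.0]}
save_checkpoint(trained, path)

# Run 2: declare the same variables, initialize them, then restore
fresh = {"w": [0.0, 0.0], "b": [0.0]}  # values set by the "init op"
restore_checkpoint(fresh, path)
print(fresh)  # {'w': [0.5, -1.5], 'b': [2.0]}
```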
