tf.train.Saver()

tf.train.Saver() 保存和加载模型

saver = tf.train.Saver()

saver.save(sess,  '路径 + 模型文件名')

import tensorflow as tf
import os
import numpy as np

a = tf.Variable(1., tf.float32)
b = tf.Variable(2., tf.float32)
num = 10

saver = tf.train.Saver()
init = tf.global_variables_initializer()
save_dir = './model/'

with tf.Session() as sess:
    for step in np.arange(num):
        sess.run(init)
        c = sess.run(tf.add(a, b))
        checkpoint_path = os.path.join(save_dir, 'model.ckpt')
        # 默认最多同时存放 5 个模型
        saver.save(sess, checkpoint_path, global_step=step)

可选参数 global_step : 编号 checkpoint 名字


Tensorflow 会自动生成4个文件

第一个文件为 model.ckpt.meta,保存了 Tensorflow 计算图的结构,可以简单理解为神经网络的网络结构。

model.ckpt.indexmodel.ckpt.data-*****-of-***** 文件保存了所有变量的取值。

最后一个文件为 checkpoint 文件,保存了一个目录下所有的模型文件列表。

import tensorflow as tf
import os
import numpy as np

a = tf.Variable(1., tf.float32)
b = tf.Variable(2., tf.float32)
num = 10

saver = tf.train.Saver()
init = tf.global_variables_initializer()
save_dir = './model/'

with tf.Session() as sess:
    sess.run(init)
    ckpt = tf.train.get_checkpoint_state(save_dir)
    # 载入模型
    saver.restore(sess, ckpt.model_checkpoint_path)
    print("load success")


猜你喜欢

转载自blog.csdn.net/yz19930510/article/details/80324389
今日推荐