tensorflow 训练模型的保存和提取



tensorflow笔记:模型的保存与训练过程可视化



保存与读取模型

在使用tf来训练模型的时候,难免会出现中断的情况。这时候自然就希望能够将辛辛苦苦得到的中间参数保留下来,不然下次又要重新开始。好在tf官方提供了保存和读取模型的方法。

保存模型的方法:

# 之前是各种构建模型graph的操作(矩阵相乘,sigmoid等等....)

saver = tf.train.Saver() # 生成saver with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 先对模型初始化 # 然后将数据丢入模型进行训练blablabla # 训练完以后,使用saver.save 来保存 saver.save(sess, "save_path/file_name") #file_name如果不存在的话,会自动创建

将模型保存好以后,载入也比较方便,如下所示:

saver = tf.train.Saver()

with tf.Session() as sess:
    #参数可以进行初始化,也可不进行初始化。即使初始化了,初始化的值也会被restore的值给覆盖 sess.run(tf.global_variables_initializer()) saver.restore(sess, "save_path/file_name") #会将已经保存的变量值resotre到 变量中。

猜你喜欢

转载自blog.csdn.net/zhouguangfei0717/article/details/80759450