TensorFlow1.x入门(11)——模型的保存与恢复

系列文章

本教程有同步的github地址

0. 统领篇

1. 计算图的创建与启动

2. 变量的定义及其操作

3. Feed与Fetch

4. 线性回归

5. 构建非线性回归模型

6. 简单分类问题

7. Dropout与优化器

8. 手动调整学习率与TensorBoard

9. 卷积神经网络(CNN)

10. 循环神经网络(RNN)

11. 模型的保存与恢复

模型的保存与恢复

引言

利用TensorFlow训练好模型可以对测试集的数据进行预测,用于评估模型的好坏。但是每次执行一个预测任务时,均从头训练一下模型,则会耗费大量的时间与资源,并且有可能结果不能完全的复现。
所以TensorFlow提供了模型的保存与恢复的接口,当你训练好模型后,可以将它持久化在本地,再次使用时可以直接恢复进行预测,不需要再重新训练。

知识点

saver=tf.train.Saver()定义一个保存模型的对象,也是固定写法。
saver.save(sess, save_path=r"...")写在session中用于在每次迭代结束后保存一下模型,也可以保存效果最好的模型。
saver.restore(sess, save_path=r"...")同样写在session中用于恢复模型的参数,其中save_path就是模型保存的地址。

示例

#%% md
# 模型的保存与恢复
#%%
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#%% md
加载数据
#%%
mnist = input_data.read_data_sets("MNIST", one_hot=True)
#%% md
设置参数batch_size的大小,计算迭代的总批次
#%%
batch_size = 100
n_batches = mnist.train.num_examples // batch_size
#%% md
构建网络
#%%
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
#%%
w = tf.Variable(tf.truncated_normal([784,10], stddev=0.1))
b = tf.Variable(tf.zeros([10]) + 0.1)
#%% md
预测输出
#%%
prediction = tf.nn.softmax(tf.matmul(x, w) + b)

#%% md
定义损失函数
#%%
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
#%% md
定义优化器
#%%
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#%% md
计算正确率
#%%
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#%% md
定义初始化的init
#%%
init = tf.global_variables_initializer()
#%% md
定义保存的对象
#%%
saver = tf.train.Saver()
#%% md
训练模型并保存
#%%
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(100):
        for batch in range(n_batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run([train_step], {x:batch_xs, y:batch_ys})
        saver.save(sess, save_path='saved_model/mymodel')
        acc, loss = sess.run([accuracy, cross_entropy], {x:mnist.test.images, y:mnist.test.labels})
        print("Iter: " + str(epoch) + " Loss: " + str(loss) + ", Testing Acc: " + str(acc))
#%% md
恢复模型进行对比
1. 未恢复参数的模型效果
2. 完全恢复模型参数的效果
#%%
with tf.Session() as sess:
    sess.run(init)
    acc, loss = sess.run([accuracy, cross_entropy], {x:mnist.test.images, y:mnist.test.labels})
    print(" Loss: " + str(loss) + ", Testing Acc: " + str(acc))
#%%
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, 'saved_model/mymodel')
    acc, loss = sess.run([accuracy, cross_entropy], {x:mnist.test.images, y:mnist.test.labels})
    print(" Loss: " + str(loss) + ", Testing Acc: " + str(acc))
#%%

猜你喜欢

转载自blog.csdn.net/qq_19672707/article/details/106082917
今日推荐