在实际中,通常需要将经过大量训练的较好模型参数保存起来,在实际应用以训练好的模型进行预测。
TensorFlow中提供了模型保存的模块 tensorflow.train.Saver()
1. 导入tensorflow模块 import tensorflow as tf
2. 创建模型保存的Saver对象 saver = tf.train.Saver
3. 保存训练好的模型,设置模型保存的路径 checkpoint_dir = './model/' , 其中model是当前路径下保存模型的文件夹名称
saver.save(sess, checkpoint_dir+'model.ckpt', global_step = step) ,model.ckpt-step 是模型的文件名,step是迭代次数。
需注意的是,最后一次迭代的训练模型有可能不是准确度最高的一次,如果想保存迭代中准确度最高的一次,需要添加判断。
在迭代训练前设置初始最大准确度 max_acc = 0
在每次迭代中进行判断
checkpoint_dir = './model/'
if val_acc > max_acc:
max_acc = val_acc
saver.save(sess, checkpoint_dir+'model.ckpt', global_step = step)
4. 用已保存的模型进行测试
model_file = tf.train.latest_checkpoint(checkpoint_dir)
saver.restore(sess, model_file)
output = sess.run(pre_result, feed_dict={x: test_x})