tensorflow 训练模型的保存 与 读取已保存的模型进行测试

在实际中,通常需要将经过大量训练的较好模型参数保存起来,在实际应用以训练好的模型进行预测。

TensorFlow中提供了模型保存的模块 tensorflow.train.Saver()

1. 导入tensorflow模块                   import tensorflow as tf

2. 创建模型保存的Saver对象      saver = tf.train.Saver

3. 保存训练好的模型,设置模型保存的路径 checkpoint_dir =  './model/' , 其中model是当前路径下保存模型的文件夹名称

     saver.save(sess, checkpoint_dir+'model.ckpt', global_step = step) ,model.ckpt-step 是模型的文件名,step是迭代次数。

    需注意的是,最后一次迭代的训练模型有可能不是准确度最高的一次,如果想保存迭代中准确度最高的一次,需要添加判断。

    在迭代训练前设置初始最大准确度 max_acc = 0

    在每次迭代中进行判断  

checkpoint_dir = './model/'
if val_acc > max_acc:
    max_acc = val_acc
    saver.save(sess, checkpoint_dir+'model.ckpt', global_step = step)

 4. 用已保存的模型进行测试 

     model_file = tf.train.latest_checkpoint(checkpoint_dir)

     saver.restore(sess, model_file)

     output = sess.run(pre_result, feed_dict={x: test_x})

                                                   

猜你喜欢

转载自blog.csdn.net/Muzi_Water/article/details/81974859