Tensorflow模型的保存与恢复

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Irving_zhang/article/details/79081694

最近在写对话生成的代码时,遇到一个问题就是在预测阶段,对于相同的输入,每一次生成的文本都不一样,而且生成的结果乱七八糟。因此定位到是训练好的模型没有restore,特此记录一下TensorFlow中模型的保存与恢复问题,即tf.train.saver函数的使用。

创建Saver

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。

保存模型

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'save/model.ckpt',global_step=step)

第一个参数sess,这个就是当前会话,记录了训练的变量值。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中,例如:

saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

在实验中,按照次序保存的最后几代可能并不是验证精度最高的一代,因此我们并不想默认保存最后几代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

saver=tf.train.Saver()
max_acc=0
for i in range(100):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  if val_acc>max_acc:
      max_acc=val_acc
      saver.save(sess,'save/model.ckpt',global_step=i+1)

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('save/acc.txt','w')
for i in range(100):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
  if val_acc>max_acc:
      max_acc=val_acc
      saver.save(sess,'save/model.ckpt',global_step=i+1)
f.close()
sess.close()

模型恢复

模型恢复用的是tf.train.restore()函数,它需要两个参数restore(sess, save_path),sess表示当前会话,之前保存的结果将被加载入这个会话,save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('save/')
saver.restore(sess,model_file)

注意:这里的latest_checkpoint函数的参数表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新保存结果的命名。

模型恢复可以用在对话生成的预测阶段,训练好了模型参数直接拿来用。给定输入,然后得到输出。也可以用在断点处开始训练,这样就不用每次训练都从头开始:

# 如果已经保存过模型,导入上次的模型
if os.path.exists(ckpt_path + "checkpoint"):
    print("Restoring Variables from Checkpoint...")
    model.saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
    # 输出模型参数,可自定义
    # last_valid_cost, precision, recall, last_f1 = valid_epoch(data_valid_path, sess, model)
    # print(' valid cost=%g; p=%g, r=%g, f1=%g' % (last_valid_cost, precision, recall, last_f1))
else:
    print('Initializing Variables...')
    sess.run(tf.global_variables_initializer())

猜你喜欢

转载自blog.csdn.net/Irving_zhang/article/details/79081694