详解:Tensorflow 模型的保存与加载

Tensorflow 模型的保存与加载

加载之前训练好的模型,继续训练防止程序意外退出拿不到训练结果。
tf.train.saver() 保存模型
tf.train.restore() 模型的加载

(一)Tensorflow模型介绍

通常我们训练好中后都会得到这样几个文件
在这里插入图片描述
(1).meta 文件
是一个协议缓冲区,可以保存完整的 Tensorflow图 即所有 变量,操作,集合 等。此文件具有.meta扩展名
(2).data 文件
.data 文件是一个二进制文件,包括 权重,偏差,渐变和所有其他保存变量的所有值 。.data-00000of00001只是后缀,加载的时候不用写,只写model.ckpt即可。详情见后面
其中,-27150表示第27150次训练得到的结果

(二)模型的保存(保存所有的所有参数的图形和值)

1、 首先要建立一个saver对象:如

saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)

max_to_keep 表示保存模型的个数,max_to_keep=5表示保存最新的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver = tf.train.Saver( max_to_keep=0)

但是这样不推荐,一般不这样做,浪费存储空间
2、 创建完saver对象后,使用saver.save()就可以保存训练好的模型了,如:

checkpoint_path='/Users/model/ade20k/model.ckpt'
#checkpoint_path  模型存储地址 后缀为.ckpt
saver.save(self.sess, checkpoint_path,global_step=step)
print('The checkpoint has been created, step: {}'.format(step))#可以打印出来

self.sess是创建的会话,因为所有的变量仅在会话中存在。因此,您必须通过在刚创建的saver对象上调用save方法将模型保存在会话session中。第二个参数设置保存的路径和名字,第三个参数global_step将训练的次数作为后缀加入到模型名字中。
完整的保存程序应该是

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()#创建对象saver
sess = tf.Session()#创建会话
sess.run(tf.global_variables_initializer())#初始化会话所有变量
saver.save(sess, 'my_test_model',global_step=1000)#将第1000次训练的模型保存到会话中

可根据需求,个性化保存训练文件
(1)让我们说,在训练时,我们在每1000次迭代后保存我们的模型,所以.meta文件是第一次创建(第1000次迭代),我们不需要每次都重新创建.meta文件(所以,我们不要t保存.meta文件在2000,3000 …或任何其他迭代)。我们只保存模型以进行进一步的迭代,因为图形不会改变。因此,当我们不想编写元图时,我们使用这个:

saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)

(2)如果您只想保留4个最新型号并希望在培训期间每2小时保存一个型号,则可以使用max_to_keep和keep_checkpoint_every_n_hours。

#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

(3)最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代

扫描二维码关注公众号,回复: 10852068 查看本文章
saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100)
    batch_xs,batch_ys=mnist.train.next_batch(100)
    sess.run(train_op,feed_dict={x: batch_xs,y_: batch_ys})
    val_loss,val_acc=sess.run([loss,acc],feed_dict={x:mnist.test.images,y_:mnist.test.labels})
    print('epoch:%d, val_loss:%f,val_acc:%f'%(i,val_loss,val_acc))
    if val_acc>max_acc:
        max_acc=val_ac        
        saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

(4)如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
  if val_acc>max_acc:
      max_acc=val_acc
      saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

(三)、模型的加载与恢复

如果想使用其他人的预训练模型进行微调,需要做两件事:
(1)创建网络结构
可以通过编写python代码来创建网络,以手动创建每个图层作为原始模型。但是,我们已将网络保存在.meta文件中,我们可以使用tf.train.import()函数重新创建网络,如下所示

new_saver = tf.train.import_meta_graph("/Users/yanni_Z/python/model.ckpt-100.meta")

import_meta_graph将.meta文件中定义的网络加载到当前图形。创建图形/网络,接着还需要加载我们在此图表上训练过的参数值。
(2)加载参数
用restore()函数,它需要两个参数restore(sess, save_path)
tf.train.latest_checkpoint()来自动获取最后一次保存的模型

with tf.Session() as sess:
 new_saver =tf.train.import_meta_graph('my_test_model-1000.meta')
 new_saver.restore(sess,tf.train.latest_checkpoint('./'))

或者加载特定的训练模型

with tf.Session() as sess:
 tf.train.import_meta_graph("/Users/yanni_Z/python/model.ckpt-100.meta")#加载网络结构
 loader.restore(sess,"/Users/yanni_z/python/model/odel.ckpt-100")#载入权重等参数

(四)使用已经恢复的模型

如果想使用已经载入的模型进行 预测、微调,进一步训练。
使用Tensorflow时,先定义一个图表,其中包含示例(训练数据)和一些超参数,如学习率,全局步骤等。使用占位符提供所有训练数据和超参数的标准做法。让我们使用占位符构建一个小型网络并保存它。 或者加载特定的训练模型

import tensorflow as tf
#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#Create a saver object which will save all the variables
saver = tf.train.Saver()
#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 
#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

当我们想要恢复它时,我们不仅要恢复图形和权重,还要准备一个新的feed_dict,将新的训练数据提供给网络。我们可以通过graph.get_tensor_by_name()方法引用这些保存的操作和占位符变量。

#How to access saved variable/Tensor/placeholders 
w1 = graph.get_tensor_by_name("w1:0")
## How to access saved operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

如果我们只想使用不同的数据运行相同的网络,您只需通过feed_dict将新数据传递到网络即可。

import tensorflow as tf
sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 
#using new values of w1 and w2 and saved value of b1. 

如果想在原来的网路图中添加更多操作,该怎么办?当然你也可以这样做。看这里:

import tensorflow as tf
sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)
print sess.run(add_on_op,feed_dict)
#This will print 120.

可以恢复旧图形的一部分并对其进行附加以进行微调?当然,您可以通过graph.get_tensor_by_name()方法访问相应的操作,并在其上构建图形。在这里,我们使用元图加载一个vgg预训练网络,并在最后一层将输出数量更改为2,以便使用新数据进行微调。

saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
# Now, you run this with fine-tuning data in sess.run()

参考:https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ (英文文章写的超级好)
https://blog.csdn.net/liuxiao214/article/details/79048136

发布了46 篇原创文章 · 获赞 9 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/weixin_43826596/article/details/90256020