tensorflow 模型保存与加载

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ 
什么是TF模型:

在训练一个神经网络模型后,你会保存这个模型未来使用或部署到产品中。所以,什么是TF模型?TF模型基本包含网络设计或图,与训练得到的网络参数和变量。因此,TF模型具有两个主要文件: 
a)meta图 
这是一个拟定的缓存,包含了这个TF图完整信息;如所有变量等等。文件以.meta结束。 
b)检查点文件: 
这个文件是一个二进制文件,包含所有权重、偏移、梯度和所有其它存储的变量的值。这个文件以.ckpy结束。然而,TF已经在0.11版本后不再以这个形式了。转而文件包含如下文件 : 
mymodel.data-00000-of-00001 
mymodel.index 
.data文件包含训练变量。 
除此之外 ,TF还包含一个名为“checkpoint”的文件 ,保存最后检查点的文件。 
所以,综上,TF模型包含如下文件 :

  • my_test_model.data-00000-of-00001 
  • my_test_model.index 
  • my_test_model.meta 
  • checkpoint**

2保存一个TF模型 
saver = tf.train.Saver() 
注意,你需要在一个session中保存这个模型 
Python 
1saver.save(sess, ‘my-model-name’) 
完整的例子为:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

如果是在TF模型迭代1000步后保存这个模型,可以指定步数 
saver.save(sess, ‘my_test_model’,global_step=1000)

3.加载一个预训练的模型 
a)创建网络 
使用tf.train.import()函数加载以前保存的网络。 
saver = tf.train.import_meta_graph(‘my-model-1000.meta’) 
注意,import_meta_graph将保存在.meta文件中的图添加到当前的图中。所以,创建了一个图/网络,但是我们使用需要加载训练的参数到这个图中。

b)加载参数

'''restore tensor from model'''
w_out= self.graph.get_tensor_by_name('W:0')
b_out = self.graph.get_tensor_by_name('b:0')
_input = self.graph.get_tensor_by_name('x:0')
_out = self.graph.get_tensor_by_name('y:0')
y_pre_cls = self.graph.get_tensor_by_name('output:0')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

注意问题1: 
初始保存位置如果为e:,则这个位置被保存在checkpoint中 
修改后: 
model_checkpoint_path: “E:\tmp\newModel\crack_capcha.model-8100” 
all_model_checkpoint_paths: “E:\tmp\newModel\crack_capcha.model-8100”

这个过程形象的描述  
Technically, this is all you need to know to create a class-based neural network that defines the fit(X, Y) and predict(X) functions.

见stackoverFlow解释 
In( and After) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph according tohttps://www.tensorflow.org/programmers_guide/meta_graph 
save model:

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

**# save method will call export_meta_graph implicitly. 
you will get saved graph files:my-model.meta** 
restore model:

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

一个完整的例子: 
self.session = tf.Session(graph=self.graph)

with self.graph.as_default():####默认图与自定义图的关系
    ckpt = tf.train.get_checkpoint_state(self.savefile)
       if ckpt and ckpt.model_checkpoint_path:
           print(''.join([ckpt.model_checkpoint_path,'.meta']))
           self.saver = tf.train.import_meta_graph(''.join([ckpt.model_checkpoint_path,'.meta']))
           self.saver.restore(self.session,ckpt.model_checkpoint_path)
       #print all variable
       for op in self.graph.get_operations():
       print(op.name, " " ,op.type)
       #返回模型中的tensor
       layers = [op.name for op in self.graph.get_operations() if op.type=='Conv2D' and 'import/' in op.name]
       layers = [op.name for op in self.graph.get_operations()]
       feature_nums = [int(self.graph.get_tensor_by_name(name+':0').get_shape()[-1]) for name in layers]
       for feature in feature_nums:
            print(feature)

     '''restore tensor from model'''
     w_out = self.graph.get_tensor_by_name('W:0')
     b_out = self.graph.get_tensor_by_name('b:0')
     _input = self.graph.get_tensor_by_name('x:0')
     _out = self.graph.get_tensor_by_name('y:0')
     y_pre_cls = self.graph.get_tensor_by_name('output:0')
     #self.session.run(tf.global_variables_initializer())   ####非常重要,不能添加这一句
        pred = self.session.run(y_pre_cls,feed_dict={_input:_X})
        return pred
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

中间有许多坑,但是成功的加载执行后,对模型的了解也加深了

猜你喜欢

转载自blog.csdn.net/weixin_38208741/article/details/80674284