计算机视觉——【tensorflow入门】模型的存储与加载

笔者作为深度学习路上的一枚小菜菜,最近查了一下如何保存模型和加载,这里做一下梳理:
模型训练中无可避免的需要做模型的保存和加载,以便于持久化数据和移植代码,这里就简单的介绍一下基于Tensorflow框架的模型保存与再载:

当保存模型的时候,会生成以下文件:
在这里插入图片描述
模型保存的数据分两部分,一部分是图(Graph)或者认为是一系列的运算(Ops);另一部分是数据,即各种变量在运行中的数值;其中.meta文件就是记录着Graph部分,而.index和.data-0000-of-00001后缀的则保存着数据部分;checkpoint为所保存的文件的最新记录,文本文件;

1 什么是Tensorflow model?

当我们训练好一个网络模型时,我们可能需要保存模型和数据甚至进一步部署到应用中。那么什么是Tensorflow 模型呢?Tensorflow模型主要包含两部分:1)网络结构设计或图结构(Graph). 2)训练好的或经过一定训练的网络参数。因此, Tensorflow model 主要有两个文件:

1. 1 Meta Graph

它是一个协议缓冲文件(a protocol buffer),用来保存完整的Tensorflow 图结构,包括所有的变量、操作、集合等,以.meta为文件后缀。

1.2 Checkpoint file

其为二进制文件,涵盖了所有的(在保存时,若没有特殊制定,则保存所有)参数(weights)、偏置(biases)、梯度值(gradients)和其他的变量。此文件以.ckpt为文件后缀。然而,从0.11版本后,Tensorflow就将其分裂成了两个文件:.data-0000-of-0001和.index,如:
在这里插入图片描述
.data文件之后的数字是变化的,0.11版本后,其存储了我们的训练参数。

而与.data .meta .index同时生成的还有一个checkpoint文本文件,其作用是为最近保存的模型文件做记录。

小结:

  1. TTensorflow-0.11版本之前的模型保存后, 文件有三个:
model.meta # Store Graph
model.ckpt # Store variables
checkpoint # Store record latest checkpoint file saved
  1. TTensorflow-0.11版本之前的模型保存后, 文件有四个:
model.meta # Stroe grap
model.data-* # Store variables
model.index
checkpoint # Store record latest checkpoint file saved

2 保存模型

假设说,你在训练一个图像分类的卷积神经网络。作为一个标准操作,需要观察模型的损失函数值和正确率,当模型收敛的时候,你可能想要停止训练或者只运行一定循环数的数据集。当训练结束的时候,通常保存所有的数据和网络图结构到文件就是下一步操作,在Tnesorflow中,是通过tf.train.Saver()类来进行保存模型等相关操作的。

tip:
需要注意的是,Tnesorflow中的变量值只在session运行期间存在,所以我们必须在会话生存期内进行变量的保存:

saver = tf.train.Saver()
with tf.Session() as sess:
	...
   saver.save(sess, 'my-test-model')`

这里,sess是一个session对象,'my-test-model’为想要赋予模型文件的名字,以下为完整代码:

import tensorflow as tf
w_1 = tf.Variables(tf.random_normal(shape[2], name='w_1'))
w_2 = tf.Variables(tf.random_normal(shape=[5]), name='w_2')
saver = tf.train.Saver()
with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	saver.save(sess, 'my-test-model')
       
# 以上代码将保存以下文件:
"""
my_test_model.data-00000-of-00001
my_test_model.index
my_test_model.meta
checkpoint
"""

如果想要在1000个循环后,保存文件,我们可以给出global_step参数:

saver.save(sess, ‘mt-test-model’, global_step=1000)

此时,保存的文件会变成:

saver.save(sess, 'mt-test-model', global_step=1000)
    
# files are saved as follows:
"""
my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
"""

进一步说,我们没过1000循环都保存文件,所以.meta文件在第一次就被保存(第一个1000循环到达时),实际上我们不需要每到1000整数次循环就在此重写.meta文件,因为其在训练过程中很少改变,则我们可以加入writer_meta_graph=False:

saver.save(sess, 'my-test-model', global_step=step, write_meta_graph=False)

如果想在磁盘上每隔2小时保存最近的4个模型文件的话,可以使用max_to_keep和keep_checkpoint_every_n_hours:

saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

tip:
如果没有特殊指定,tf.train.Saver()将保存所有变量。

若指向保存部分参数,则可以指定变量名或集合,当创建saver时,直接传入参数:

import tensorflow as tf
w_1 = tf.Variable(tf.random_normal(shape=[2]), name='w_1')
w_2 = tf.Variable(tf.random_normal(shape=[5]), name='w_2')
saver = tf.train.Saver([w_1,w_2])
with tf.Session() as sess:
     sess.run(tf.global_variables_initializer())
     saver.save(sess, 'my_test_model',global_step=1000)

3 加载模型文件

当需要使用别人训练好的模型时,需要有以下两个步骤:

tip:
模型和参数是分别加载的,如果仅使用tf.train.Saver.restore(sess, <ckpt_dir>)进行数据加载,则仅仅加载了数据部分,网络结构并没有加载;

3.1 加载网络结构 (Graph)

使用训练好的数据或权重时,可以通过两种方法对网络结构进行重建:一是自己手动再建网络(确保参数的名称、初始化维度和数据类型都要和训练时候的一模一样,不然加载权重数据时会出错);二是直接加载模型.meta文件,通过tf.train.import()类样的函数:

tf.train.import_meta_graph('my-test-model.meta')

谨记,import_meta_graph其作用是将.meta文件中预先定义好的图(Graph, 网络结构)添加到当前的图(Graph)中。所以,它会自动重建与.meta中一样的网络结构,但是此时还仅仅是一副’空结构‘,若想进行测试或者访问训练好的数据,我们仍需要加载数据文件

3.2 载入数据

我们可任意通过以下方法加载最近一次的数据文件:

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my-test-graph.meta')
    new_saver.restore(sess, tf.train.latest_checkpint('./'))

这时候,我们才完整的加载了网络(包括结构和保存模型时的数据),这时,可以对定义的w_1, w_2进行访问了,代码如下:

with tf.Session() as sess:    
     saver = tf.train.import_meta_graph('my-model-1000.meta')
     saver.restore(sess,tf.train.latest_checkpoint('./'))
     print(sess.run('w_1:0'))

截止到目前,已经完整的讲述了Tensorlfow模型保存和重现的基本步骤,下面就介绍一些更深入的使用。

4 加载模型文件后的更改

这部分介绍加载后模型的预测、微调甚至进一步训练。无论什么时候使用Tensorflow,都无法避免定义Graph,其包含了输入数据和一些超参数的设定,如学习率等,使用占位符(placeholders)来进行数据的输入是常用的操作。下面使用一以一个小的网络作为例子:

tip:
网络数据的保存,并不包括占位符(所以恢复模型后,若输入数据存在占位符,必须给出相同维度和数据类型的输入,模型才能正常运作);
它只是起到占用内存的作用,但是并没有数据,grap只知道这个操作的存在,但是数据是在人为启动程序并给出数据后才知道确切的数值。

注意这里输入是占位符,恢复模型后的输入也要有占位符的定义

import tensorflow as tf
    
# Prepare to feed input, i.e. feed_dict and placeholders
w_1 = tf.placeholder("float", name="w_1")
w_2 = tf.placeholder("float", name="w_2")
b_1= tf.Variable(2.0,name="bias")
feed_dict ={w_1:4,
			w_2:8}
    
# Define a test operation that we will restore
w_3 = tf.add(w_1,w_2)
w_4 = tf.multiply(w_3,b_1,name="op_to_restore")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    # Create a saver object which will save all the variables
    saver = tf.train.Saver()
    
    # Run the operation by feeding input
    print sess.run(w4,feed_dict)
    # Prints 24 which is sum of (w_1+w_2)*b_1 
    
    # Now, save the graph
    saver.save(sess, 'my_test_model',global_step=1000)

上述代码会保存模型文件,现在让我们进一步去恢复它,想要访问其中的数据,可以通过graph.get_tensor_by_name()方法;

# How to access saved variable/Tensor/placeholders 
w1 = graph.get_tensor_by_name("w1:0")
    
## How to access saved operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

若想进一步训练,则需要重新定义字典结构的输入数据:

import tensorflow as tf
    
sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
    
    
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
    
graph = tf.get_default_graph()
w_1 = graph.get_tensor_by_name("w_1:0")
w_2 = graph.get_tensor_by_name("w_2:0")
feed_dict ={w1:13.0, w2:17.0} # New data
    
# Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
    
print sess.run(op_to_restore,feed_dict)
# This will print 60 which is calculated 
# using new values of w1 and w2 and saved value of b1. 

如果想在重建后的网络加入一些新的层并进一步训练,也是可以的:

# Add new Ops
import tensorflow as tf
    
sess=tf.Session()    
# First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my-test-model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
    
    
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
    
graph = tf.get_default_graph()
w_1 = graph.get_tensor_by_name("w_1:0")
w_2 = graph.get_tensor_by_name("w_2:0")
feed_dict ={w_1:13.0,w_2:17.0}
    
# Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
    
# Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)
    
print sess.run(add_on_op,feed_dict)
# This will print 120.

加入新的层:

# Append new layer
......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 
    
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
    
# Use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
    
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
    
# Now, you run this with fine-tuning data in sess.run()

小结:

  1. 当模型的输入没有占位符时(极少情况),在恢复模型和模型参数时,可直接开启一个会话,run某个变量(此变量是原来模型中定义好的)来获取其值:
w_1 = tf.Variable(tf.random_normal(shape=[2]), name='w_1')
with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
    saver.save(sess, os.path.join('my_test_model','','test'))
       
    
# Restore model
with tf.Session() as sess:
	new_saver = tf.train.import_meta_graph(os.path.join('my_test_model',
                                                            '','test.meta'))
	new_saver.restore(sess, os.path.join('my_test_model','','test'))   
	print('w_1:\t',sess.run('w_1:0'))

tip:

  1. 注意恢复后的模型,获取某个操作涵盖以下两个步骤:

    • graph=tf.get_default_graph()获取默认的图
    • graph.get_tensor_by_name(op_to_resotre)去获取恢复后模型结构中的某个操作;
      一定要注意,这里是get_tensor_by_name
  2. 条目1和还未保存时候的网络运行获取其图中操作的方式不同:

    • graph=sess.graph 获取当前会话的图
    • graph.get_operations() 获取图中所有操作
      graph.get_operation_by_name() 获取指定名称的操作
  1. 存在占位符形式的数据为输入时候,需要从恢复后的graph中获取占位符张量:
import os, sys, datetime
import numpy as np
import tensorflow as tf

# Store model with a input of placeholder
x_input_placeholer = tf.placeholder(dtype=tf.float32, name='x_input_placeholer')
x_var = tf.Variable(tf.truncated_normal([2,2]), dtype=tf.float32, name='x_var ')
y = tf.add(x_var , x_input_placeholer , name='y') # broadcast

saver_2 = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(y, feed_dict={x_input:3})
    saver_2.save(sess, os.path.join('my_test_model','','test-2'))

import os, sys, datetime
import numpy as np
import tensorflow as tf
# Re-store model
with tf.Session() as sess:
    meta_path = os.path.join('my-test-model','','test-2.meta')
    print('restore path:', meta_path) # out: restore path: my_test_model\test-2.meta
    
    new_saver = tf.train.import_meta_graph(meta_path)
    ckpt_path = os.path.join('my-test-model','')
    new_saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
    
    # OK
    print(sess.run('x_var :0')) # Could directly run for tensors NOT be placeholders
    
    graph = tf.get_default_graph()
    x_input = graph.get_tensor_by_name('x_input_placeholer:0') # placeholder
    
    print(sess.run('y:0', feed_dict={x_input:5}))

"""
out:
restore path: my_test_model\test-2.meta
[[-0.6829829  -0.6992631 ]
 [ 0.10211307 -0.5825841 ]]
[[4.317017  4.300737 ]
 [5.1021132 4.417416 ]]
"""

参考:

  1. A quick complete tutorial to save and restore Tensorflow models. 一个很好的网站,但是是英文帖子网站,内容并不是很系统,覆盖面有限,但是内容都很好,这里安利一下下~
  2. tensorflow线上手册
发布了47 篇原创文章 · 获赞 55 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/u011106767/article/details/96424315
今日推荐