This article explains how to save, load, and use models with TensorFlow in Python.
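When training deep learning models with TensorFlow, you need to save the trained network and its parameters so that you can continue training from them or use them later. Many blog posts cover this topic; the best one I found is this English tutorial on cv-tricks.com: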
http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
Below is my organized summary of that article.
First, saving the model. Straight to the code:
```python
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut1_save.py
#Author: Wang
#Created Time:2017-08-30 11:04:25
############################
# TensorFlow 1.x API (tf.Session, tf.train.Saver)
import tensorflow as tf

# Prepare the inputs to feed: here the variables w1 and w2 are fed directly
# through feed_dict (placeholders are the more common choice)
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')  # name is very important in restoration
w2 = tf.Variable(tf.random_normal(shape=[2]), name='w2')
b1 = tf.Variable(2.0, name='bias1')
feed_dict = {w1: [10, 3], w2: [5, 5]}

# Define a test operation that will be restored.
# Without an explicit name, w3 only gets an auto-generated name,
# which makes it hard to look up after restoring.
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")

# saver = tf.train.Saver()
# Keep at most the 4 most recent checkpoints, plus one checkpoint per hour
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(w4, feed_dict))
# saver.save(sess, 'my_test_model', global_step=100)
saver.save(sess, 'my_test_model')
# saver.save(sess, 'my_test_model', global_step=100, write_meta_graph=False)
```
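A few points to note:
1. When creating the saver, you can specify which variables to save (the var_list argument); if you don't, all variables are saved. You can also set the maximum number of checkpoints to keep (max_to_keep) and the interval at which an extra checkpoint is kept permanently (keep_checkpoint_every_n_hours). See the English tutorial for details.
2. saver.save() accepts global_step and write_meta_graph. The .meta file stores the network structure, which only needs to be written once, the first time the program saves; later saves can skip it by setting write_meta_graph=False.
3. When this program finishes, four files appear in its directory: .meta (the network structure), .data and .index (the trained parameters), and checkpoint (a record of the latest model).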
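To make points 1 to 3 concrete, here is a minimal sketch. It assumes TensorFlow 1.15 or 2.x accessed through the tensorflow.compat.v1 API in graph mode, and a hypothetical scratch directory created with tempfile.mkdtemp(). The sketch saves only bias1 via var_list and saves three times with global_step. It writes the .meta file only on the first save.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + tf.Session, as in the TF 1.x code above

ckpt_dir = tempfile.mkdtemp()  # hypothetical scratch directory
prefix = os.path.join(ckpt_dir, "my_test_model")

with tf.Graph().as_default():
    w1 = tf.Variable(tf.random_normal(shape=[2]), name="w1")
    b1 = tf.Variable(2.0, name="bias1")
    # var_list: only b1 is written to the checkpoint; w1 is left out
    saver = tf.train.Saver(var_list=[b1], max_to_keep=4)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in (100, 200, 300):
            # global_step is appended to the prefix (my_test_model-100, ...);
            # the .meta graph is written only on the first save
            saver.save(sess, prefix, global_step=step,
                       write_meta_graph=(step == 100))

files = sorted(os.listdir(ckpt_dir))
latest = tf.train.latest_checkpoint(ckpt_dir)  # read from the "checkpoint" file
saved_names = [name for name, _ in tf.train.list_variables(latest)]
print(files)
print(latest)
print(saved_names)
```

Each save with global_step gets its own prefix, such as my_test_model-100. The checkpoint file points at the newest one, and tf.train.latest_checkpoint reads it from there.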
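Next, loading a saved model. There are two ways to do it.

The first is saver.restore(sess, 'aaaa.ckpt'). It reads all saved parameters and loads them into a network structure you have already defined. In effect, it assigns values to the network's weights and biases, doing the job that tf.global_variables_initializer() would otherwise do. Its drawback is that you must rewrite the network structure in code before using it, and that structure must match the saved parameters exactly.

The second way is more advanced: it loads the network structure itself directly from the .meta file. Here's the code: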
```python
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut2_import.py
#Author: Wang
#Created Time:2017-08-30 14:16:38
############################
import tensorflow as tf

sess = tf.Session()
# Load the network structure from the .meta file...
new_saver = tf.train.import_meta_graph('my_test_model.meta')
# ...then restore the parameter values from the latest checkpoint in the current directory
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
print(sess.run('w1:0'))
```
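For comparison, here is a sketch of the first loading method: rewrite the network structure in code, then call saver.restore. It assumes the same tensorflow.compat.v1 setup. build_graph is a hypothetical helper, and the model is saved to a temporary directory first so the example runs on its own. The rebuilt variables must have the same names and shapes as the saved ones, and no initializer is run before restoring.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), "my_test_model")  # hypothetical path


def build_graph():
    """Re-create the same variables (same names, shapes, dtypes) as the saved model."""
    w1 = tf.Variable(tf.random_normal(shape=[2]), name="w1")
    w2 = tf.Variable(tf.random_normal(shape=[2]), name="w2")
    b1 = tf.Variable(2.0, name="bias1")
    w4 = tf.multiply(tf.add(w1, w2), b1, name="op_to_restore")
    return w1, w4


# Save once, like tut1_save.py
with tf.Graph().as_default():
    w1, _ = build_graph()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved_w1 = sess.run(w1)
        saver.save(sess, prefix)

# Method 1: rebuild the graph in code, then restore the values into it
with tf.Graph().as_default():
    w1, w4 = build_graph()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, prefix)  # replaces running the initializer
        restored_w1 = sess.run(w1)

print(saved_w1, restored_w1)
```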
Using the loaded model: feed in new data and compute the output. Again, straight to the code:
```python
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut3_reuse.py
#Author: Wang
#Created Time:2017-08-30 14:33:35
############################
import tensorflow as tf

sess = tf.Session()
# First, load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Second, access the tensors to feed and create feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1: [-1, 1], w2: [4, 6]}

# Access the op that you want to run
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
print(sess.run(op_to_restore, feed_dict))
# output: [ 6. 14.]
```
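A practical question with this approach is how to find names like 'w1:0' and 'op_to_restore:0'. A tensor name is the name of the op that produces it, followed by ":" and the output index. After import_meta_graph, you can list the op names with graph.get_operations(). The following is a minimal sketch under the same tensorflow.compat.v1 assumption. It first saves a small model to a hypothetical temporary directory so it runs on its own.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), "my_test_model")  # hypothetical path

# Build and save a small model like tut1_save.py, with fixed values
with tf.Graph().as_default():
    w1 = tf.Variable([1.0, 2.0], name="w1")
    w2 = tf.Variable([3.0, 4.0], name="w2")
    b1 = tf.Variable(2.0, name="bias1")
    tf.multiply(tf.add(w1, w2), b1, name="op_to_restore")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, prefix)

# Import the structure from .meta and list op names to see what can be fetched
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(prefix + ".meta")
    op_names = [op.name for op in graph.get_operations()]
    with tf.Session() as sess:
        saver.restore(sess, prefix)
        # "<op name>:<output index>" is the tensor name
        result = sess.run(graph.get_tensor_by_name("op_to_restore:0"))

print([n for n in op_names if not n.startswith("save")])
print(result)
```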
Adding new layers on top of the loaded network:
```python
import tensorflow as tf

sess = tf.Session()
# First, load the meta graph and restore the weights
# (file name matches the model saved by tut1_save.py, which used no global_step)
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Now, access the tensors to feed and create a feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
# w1 and w2 have shape [2], so they must be fed length-2 lists
feed_dict = {w1: [13.0, 13.0], w2: [17.0, 17.0]}

# Now, access the op that you want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

# Add more to the current graph
add_on_op = tf.multiply(op_to_restore, 2)

print(sess.run(add_on_op, feed_dict))
# This will print [120. 120.], i.e. (13 + 17) * 2.0 * 2
```
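Ops added after loading, like add_on_op above, exist only in the current in-memory graph. To reuse them later, one option is to give the new op a name and save again with the imported saver. Doing so writes a new .meta file that contains the op. The following sketch assumes the same tensorflow.compat.v1 setup and uses hypothetical file names in a temporary directory.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

tmp = tempfile.mkdtemp()
base = os.path.join(tmp, "my_test_model")          # hypothetical paths
extended = os.path.join(tmp, "my_extended_model")

# Save a base model with fixed values so the result is predictable
with tf.Graph().as_default():
    w1 = tf.Variable([13.0, 13.0], name="w1")
    w2 = tf.Variable([17.0, 17.0], name="w2")
    b1 = tf.Variable(2.0, name="bias1")
    tf.multiply(tf.add(w1, w2), b1, name="op_to_restore")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, base)

# Load it, add a named op, and save the extended graph under a new prefix
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(base + ".meta")
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
    tf.multiply(op_to_restore, 2.0, name="add_on_op")
    with tf.Session() as sess:
        saver.restore(sess, base)
        saver.save(sess, extended)  # the new .meta now includes add_on_op

# Later: fetch the added op by name from the extended model
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(extended + ".meta")
    with tf.Session() as sess:
        saver.restore(sess, extended)
        result = sess.run(graph.get_tensor_by_name("add_on_op:0"))

print(result)
```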
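Making partial modifications to the loaded network (this is the most involved part; I haven't fully figured it out yet and will add more later):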
```python
......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning

# Access the appropriate output for fine-tuning
fc7 = graph.get_tensor_by_name('fc7:0')

# Use this if you only want to update the last (new) layer:
# tf.stop_gradient is an identity function in the forward pass,
# but it blocks backpropagation into the layers below fc7
fc7 = tf.stop_gradient(fc7)
fc7_shape = fc7.get_shape().as_list()

num_outputs = 2
# fc7 is a fully connected output of shape [batch, features],
# so the new layer's input size is its last dimension
weights = tf.Variable(tf.truncated_normal([fc7_shape[-1], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
# Now, you run this with fine-tuning data in sess.run()
```
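Taking the fine-tuning step further, training the new layer could look roughly like the sketch below. It assumes the tensorflow.compat.v1 API. In place of VGG it uses a tiny hypothetical stand-in: an input placeholder x and an fc7 layer, saved first so the example is self-contained. The sketch works in four steps:

- restore the pretrained variables from the checkpoint;
- initialize only the new variables;
- restrict the optimizer's var_list to the new layer;
- confirm that the pretrained weights stay unchanged.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), "pretrained")  # hypothetical stand-in for vgg

# 1) Save a tiny "pretrained" network: x -> fc7 (a placeholder for a real VGG)
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 4], name="x")
    w = tf.Variable(tf.random_normal([4, 8]), name="fc7_w")
    tf.nn.relu(tf.matmul(x, w), name="fc7")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, prefix)

# 2) Fine-tune: import the graph, block gradients at fc7, train a new output layer
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(prefix + ".meta")
    x = graph.get_tensor_by_name("x:0")
    fc7 = tf.stop_gradient(graph.get_tensor_by_name("fc7:0"))
    pretrained_w = [v for v in tf.global_variables() if v.op.name == "fc7_w"][0]

    num_outputs = 2
    fc7_dim = fc7.get_shape().as_list()[-1]  # fc7 is [batch, features]
    weights = tf.Variable(tf.truncated_normal([fc7_dim, num_outputs], stddev=0.05), name="new_w")
    biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]), name="new_b")
    logits = tf.matmul(fc7, weights) + biases
    pred = tf.nn.softmax(logits)

    labels = tf.placeholder(tf.int64, [None], name="labels")
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
    new_vars = [weights, biases]
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=new_vars)

    with tf.Session() as sess:
        saver.restore(sess, prefix)                   # pretrained weights from the checkpoint
        sess.run(tf.variables_initializer(new_vars))  # initialize only the new layer
        w_before = sess.run(pretrained_w)
        data = np.random.rand(16, 4).astype(np.float32)  # hypothetical fine-tuning data
        y = (data.sum(axis=1) > 2.0).astype(np.int64)
        for _ in range(20):
            sess.run(train_op, {x: data, labels: y})
        w_after = sess.run(pretrained_w)
        probs = sess.run(pred, {x: data})

print(probs.shape)
```

Restricting var_list alone would already keep the pretrained weights fixed. tf.stop_gradient additionally avoids computing gradients for the layers below fc7.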
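With these methods, they all work great, whichever of these you are doing:
- training from scratch;
- loading a model to continue training;
- using a classic model;
- fine-tuning a classic model;
- loading a network just to run the forward pass.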
That's all for this article. I hope it helps with your studies, and please keep supporting 脚本之家.