1.保存和读取
1.1 保存
import tensorflow as tf
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(aa))
# Step 1 保存
saver.save(sess,'./ttt')
>>
[[ 0.8604646 0.45935377 -0.24135743 -2.2841513 ]
[-0.20688622 0.60574555 -0.26031223 -0.441991 ]
[-0.22254886 1.4805079 -1.7360271 1.1423918 ]]
这儿我们定义了一个name=var
的变量(随便说一句aa
这类名称是我们写程序时用以区分各个变量之间的依据,换句话说是给我们自己看的;而var
这个名字是tensorflow计算图上用来区分各个变量和操作的依据),并且将其进行了保存。
1.2 读取
说到读取,就有两个方面了:第一,知道参数的名字(上面的var)时之间读取该变量;第二,不知道参数的名称时可以先打出所有变量,然后找你所要变量对应的名称再按名读取就行。
#----------------------直接按名读取---------------------------
import tensorflow as tf
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:
print(sess.run(tf.get_default_graph().get_tensor_by_name('var:0')))
>>
[[ 0.8604646 0.45935377 -0.24135743 -2.2841513 ]
[-0.20688622 0.60574555 -0.26031223 -0.441991 ]
[-0.22254886 1.4805079 -1.7360271 1.1423918 ]]
#----------------------查看所有变量名---------------------------
import tensorflow as tf
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './ttt')
var_list = [v.name for v in tf.global_variables()]
print(var_list)
print(sess.run(var_list))
['var:0']
[array([[ 0.8604646 , 0.45935377, -0.24135743, -2.2841513 ],
[-0.20688622, 0.60574555, -0.26031223, -0.441991 ],
[-0.22254886, 1.4805079 , -1.7360271 , 1.1423918 ]],
dtype=float32)]
可以看到读取变量后的输出值和保存时的一样。
2.哪些变量能够保存
其实saver.save()
在保存参数的时候是有选择的(我说的选择不是通过save()参数里面控制的参数),看例子:
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='aa')
bb = tf.constant(0,dtype=tf.float32,name='bb')
cc = tf.zeros(shape=[5],dtype=tf.float32,name='cc')
dd = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='cc')
ee = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='ee',trainable=False)
i = 10
saver = tf.train.Saver()
with tf.Session() as sess:
# Step 1 保存
sess.run(tf.global_variables_initializer())
saver.save(sess,'./ttt')
这儿我们一共定义了6个参数,其中有三个tensor变量(aa,dd,ee)和两个tensor常量(bb,cc),和一个普通变量,我们来看一下哪些参数保存成功:
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='aa')
bb = tf.constant(0,dtype=tf.float32,name='bb')
cc = tf.zeros(shape=[5],dtype=tf.float32,name='cc')
dd = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='cc')
ee = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='ee',trainable=False)
i = 10
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './ttt')
var_list = [v.name for v in tf.global_variables()]
# print(sess.run(list_before_train))
print(var_list)
>>
['aa:0', 'cc_1:0', 'ee:0']
我们可以看到,这儿只有3个变量被保存成功,aa,ee,cc_1
。明显,aa指得就是第1行代码定义得变量,ee指得就是第5行代码,那么cc_1指得是第3行还是第4行呢? 指得是第4行,这也印证tensorflow内部是通过name='var'
这个参数来区分的。
由此我们可以得出:saver.save()
只保存tensor变量,也就是tf.Variable()
定义的变量,其它量包括tensor常量都是不被保存的。
3.网络模型的参数也能这样来保存么?
答案是:能!
这里以一个rnn cell按时间维度展开为例:
#--------------------------------------------保存--------------------------------------
import tensorflow as tf
import numpy as np
output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]], # x1
[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]], # x2
[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]]) # x3
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# Step 1 保存
saver.save(sess,'./ttt')
#-------------------------------------------读取--------------------------------------
import tensorflow as tf
import numpy as np
output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]], # x1
[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]], # x2
[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]]) # x3
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './ttt')
var_list = [v.name for v in tf.global_variables()]
print(var_list)
>>
['rnn/basic_rnn_cell/kernel:0', 'rnn/basic_rnn_cell/bias:0']
既然都保存了,那为什么这儿只有两个变量呢?那是因为tensorflow内部在计算时为了方便或是更快,把所有的weight和bias都叠在一起了,具体参见此处!
另外说明一下:
在网上看到很多人提问LSTM训练好的模型“保存不了”。为什么会觉得保存不了呢? 因为在当训练到某个时候loss已经很低了,当stop后再次载入最新几个模型时都发现loss急剧升高,因此就会决定是因为模型的参数没有保存成功而导致的,因为本人在这两天也出现了这个问题。于是网上各种搜查LSTM模型保存的方法,试了一大堆依旧无效,后来终于发现是由于同一个函数在不同平台(windwo,linux)上的处理结果居然不一样,导致预处理后的训练集一直在变而导致的!
另外,你还可以通过在每次保存LSTM模型时,打印出其中某个参数的具体值,然后手动stop;当你再次载入模型时,立马输出同一个变量,对比一下是否相同,如果相同则说明保存成功。依照我自己的实验来看,两者是相同的。
print('------保存时的值----->')
print(sess.run(tf.get_default_graph().get_tensor_by_name(
'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0'))[:5, :4])
print('载入时的值----------->')
print(sess.run(tf.get_default_graph().get_tensor_by_name(
'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0'))[:5,:4])
# 同一个变量,相同部分的值