如何查看Tensoflow模型中已保存的参数

版权声明:转载请注明出处 https://blog.csdn.net/The_lastest/article/details/83957290

1.保存和读取

1.1 保存

import tensorflow as tf


aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(aa))
    # Step 1  保存
    saver.save(sess,'./ttt')
    
>>
[[ 0.8604646   0.45935377 -0.24135743 -2.2841513 ]
 [-0.20688622  0.60574555 -0.26031223 -0.441991  ]
 [-0.22254886  1.4805079  -1.7360271   1.1423918 ]]

这儿我们定义了一个name=var的变量(随便说一句aa这类名称是我们写程序时用以区分各个变量之间的依据,换句话说是给我们自己看的;而var这个名字是tensorflow计算图上用来区分各个变量和操作的依据),并且将其进行了保存。

1.2 读取

说到读取,就有两个方面了:第一,知道参数的名字(上面的var)时之间读取该变量;第二,不知道参数的名称时可以先打出所有变量,然后找你所要变量对应的名称再按名读取就行。

#----------------------直接按名读取---------------------------

import tensorflow as tf
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:
    print(sess.run(tf.get_default_graph().get_tensor_by_name('var:0')))
    
>>

[[ 0.8604646   0.45935377 -0.24135743 -2.2841513 ]
 [-0.20688622  0.60574555 -0.26031223 -0.441991  ]
 [-0.22254886  1.4805079  -1.7360271   1.1423918 ]]

#----------------------查看所有变量名---------------------------
import tensorflow as tf
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:

    saver.restore(sess, './ttt')
    var_list = [v.name for v in tf.global_variables()]
    print(var_list)
    print(sess.run(var_list))
    
['var:0']
[array([[ 0.8604646 ,  0.45935377, -0.24135743, -2.2841513 ],
       [-0.20688622,  0.60574555, -0.26031223, -0.441991  ],
       [-0.22254886,  1.4805079 , -1.7360271 ,  1.1423918 ]],
      dtype=float32)]

可以看到读取变量后的输出值和保存时的一样。

2.哪些变量能够保存

其实saver.save()在保存参数的时候是有选择的(我说的选择不是通过save()参数里面控制的参数),看例子:

aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='aa')
bb = tf.constant(0,dtype=tf.float32,name='bb')
cc = tf.zeros(shape=[5],dtype=tf.float32,name='cc')
dd = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='cc')
ee = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='ee',trainable=False)

i = 10

saver = tf.train.Saver()
with tf.Session() as sess:
    # Step 1  保存
    sess.run(tf.global_variables_initializer())
    saver.save(sess,'./ttt')

这儿我们一共定义了6个参数,其中有三个tensor变量(aa,dd,ee)和两个tensor常量(bb,cc),和一个普通变量,我们来看一下哪些参数保存成功:

aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='aa')
bb = tf.constant(0,dtype=tf.float32,name='bb')
cc = tf.zeros(shape=[5],dtype=tf.float32,name='cc')
dd = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='cc')
ee = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='ee',trainable=False)
i = 10

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './ttt')
    var_list = [v.name for v in tf.global_variables()]
    # print(sess.run(list_before_train))
    print(var_list)

>>

['aa:0', 'cc_1:0', 'ee:0']

我们可以看到,这儿只有3个变量被保存成功,aa,ee,cc_1。明显,aa指得就是第1行代码定义得变量,ee指得就是第5行代码,那么cc_1指得是第3行还是第4行呢? 指得是第4行,这也印证tensorflow内部是通过name='var'这个参数来区分的。

由此我们可以得出:saver.save()只保存tensor变量,也就是tf.Variable()定义的变量,其它量包括tensor常量都是不被保存的。

扫描二维码关注公众号,回复: 4023414 查看本文章

3.网络模型的参数也能这样来保存么?

答案是:能!

这里以一个rnn cell按时间维度展开为例:

#--------------------------------------------保存--------------------------------------

import tensorflow as tf
import numpy as np

output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x1
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x2
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]])  # x3
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Step 1  保存
    saver.save(sess,'./ttt')

#-------------------------------------------读取--------------------------------------


import tensorflow as tf
import numpy as np

output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x1
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x2
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]])  # x3
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)

saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, './ttt')
    var_list = [v.name for v in tf.global_variables()]
    print(var_list)
    
>>
['rnn/basic_rnn_cell/kernel:0', 'rnn/basic_rnn_cell/bias:0']

既然都保存了,那为什么这儿只有两个变量呢?那是因为tensorflow内部在计算时为了方便或是更快,把所有的weight和bias都叠在一起了,具体参见此处!

另外说明一下:

在网上看到很多人提问LSTM训练好的模型“保存不了”。为什么会觉得保存不了呢? 因为在当训练到某个时候loss已经很低了,当stop后再次载入最新几个模型时都发现loss急剧升高,因此就会决定是因为模型的参数没有保存成功而导致的,因为本人在这两天也出现了这个问题。于是网上各种搜查LSTM模型保存的方法,试了一大堆依旧无效,后来终于发现是由于同一个函数在不同平台(windwo,linux)上的处理结果居然不一样,导致预处理后的训练集一直在变而导致的!

另外,你还可以通过在每次保存LSTM模型时,打印出其中某个参数的具体值,然后手动stop;当你再次载入模型时,立马输出同一个变量,对比一下是否相同,如果相同则说明保存成功。依照我自己的实验来看,两者是相同的。


print('------保存时的值----->')
print(sess.run(tf.get_default_graph().get_tensor_by_name(
'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0'))[:5, :4])
                      

print('载入时的值----------->')
print(sess.run(tf.get_default_graph().get_tensor_by_name(
'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0'))[:5,:4])


#   同一个变量,相同部分的值

猜你喜欢

转载自blog.csdn.net/The_lastest/article/details/83957290
今日推荐