【模型参数】tensorflow1.x (slim) 和tensorflow2.x (keras) 的查看模型参数方式

tensorflow1.x (slim)

1. 查看模型所有参数 / 指定参数:

tf.trainable_variables()
sess.run(variable_names)

import tensorflow as tf  # v 1.13.2
print("the params in model:")
variable_names = [v.name for v in tf.trainable_variables()]
values = sess.run(variable_names)
for k, v in zip(variable_names, values):
    print(k)  # 参数名
    print(v)  # 参数值

####################
# 如需查看特定变量名的值
v = [v for k,v in zip(variable_names, values) if k == 'DS-CNN/conv_1/batch_norm/beta'][0]
print(v)

2. 查看ckpt保存的所有模型参数 / 指定参数

pywrap_tensorflow.NewCheckpointReader(ckpt_file)

from tensorflow.python import pywrap_tensorflow
print("the params in ckpt:")
# tmp 文件夹下保存有:
# 1)ds_cnn.ckpt-13400000.index
# 2)ds_cnn.ckpt-13400000.meta
# 3)ds_cnn.ckpt-13400000.data-00000-of-00001
ckpt_file = '/tmp/ds_cnn.ckpt-13400000'  
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_file)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
    print(key)
    print(reader.get_tensor(key))

#################
# 如需打印特定参数:
print(reader.get_tensor('DS-CNN/conv_1/batch_norm/beta'))

tensorflow2.x (keras)

1. 查看模型所有参数 / 指定参数:

model.variables

# model 是已经创建好的keras模型
model_variables = model.variables
for v in model_variables:
   print(v.name)
   print(v.value)
   
####################
# 如需查看特定变量名的值
v = [v for v in model_variables if v.name == 'DS-CNN/conv_1/batch_norm/beta'][0]
print(v)
# 如需修改特定变量名的值,则在查找到后进行修改即可:
v.assign(1)

2. 查看ckpt保存的所有模型参数 / 指定参数

tf.train.load_checkpoint(ckpt_file)

import tensorflow as tf
ckpt_file = '/tmp/ds_cnn.ckpt-13400000'  
ckpt_var = tf.train.load_checkpoint(ckpt_file)
for key in ckpt_var:
    print(key)
    print(ckpt_var.get_tensor(key))
  
#################
# 如需打印特定参数:
print(ckpt_var.get_tensor('DS-CNN/conv_1/batch_norm/beta'))

猜你喜欢

转载自blog.csdn.net/u010637291/article/details/108143002
今日推荐