文章目录
tensorflow1.x (slim)
1. 查看模型所有参数 / 指定参数:
tf.trainable_variables()
sess.run(variable_names)
import tensorflow as tf # v 1.13.2
print("the params in model:")
variable_names = [v.name for v in tf.trainable_variables()]
values = sess.run(variable_names)
for k, v in zip(variable_names, values):
print(k) # 参数名
print(v) # 参数值
####################
# 如需查看特定变量名的值
v = [v for k,v in zip(variable_names, values) if k == 'DS-CNN/conv_1/batch_norm/beta'][0]
print(v)
2. 查看ckpt保存的所有模型参数 / 指定参数
pywrap_tensorflow.NewCheckpointReader(ckpt_file)
from tensorflow.python import pywrap_tensorflow
print("the params in ckpt:")
# tmp 文件夹下保存有:
# 1)ds_cnn.ckpt-13400000.index
# 2)ds_cnn.ckpt-13400000.meta
# 3)ds_cnn.ckpt-13400000.data-00000-of-00001
ckpt_file = '/tmp/ds_cnn.ckpt-13400000'
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_file)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
print(key)
print(reader.get_tensor(key))
#################
# 如需打印特定参数:
print(reader.get_tensor('DS-CNN/conv_1/batch_norm/beta'))
tensorflow2.x (keras)
1. 查看模型所有参数 / 指定参数:
model.variables
# model 是已经创建好的keras模型
model_variables = model.variables
for v in model_variables:
print(v.name)
print(v.value)
####################
# 如需查看特定变量名的值
v = [v for v in model_variables if v.name == 'DS-CNN/conv_1/batch_norm/beta'][0]
print(v)
# 如需修改特定变量名的值,则在查找到后进行修改即可:
v.assign(1)
2. 查看ckpt保存的所有模型参数 / 指定参数
tf.train.load_checkpoint(ckpt_file)
import tensorflow as tf
ckpt_file = '/tmp/ds_cnn.ckpt-13400000'
ckpt_var = tf.train.load_checkpoint(ckpt_file)
for key in ckpt_var:
print(key)
print(ckpt_var.get_tensor(key))
#################
# 如需打印特定参数:
print(ckpt_var.get_tensor('DS-CNN/conv_1/batch_norm/beta'))