TensorFlow：获取 ckpt（checkpoint）文件中保存的参数

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

方法 1：使用 print_tensors_in_checkpoint_file 输出 checkpoint 中的全部参数

 


# Method 1: print every tensor stored in the latest checkpoint under model_dir.
model_dir = "tfPara/"

# get_checkpoint_state returns None when model_dir contains no checkpoint
# state file; fail with a readable message instead of an AttributeError on
# `.model_checkpoint_path` below.
ckpt = tf.train.get_checkpoint_state(model_dir)
if ckpt is None or not ckpt.model_checkpoint_path:
    raise OSError("No checkpoint found in %r" % model_dir)
ckpt_path = ckpt.model_checkpoint_path

# all_tensors=True dumps every tensor in the checkpoint; tensor_name='' means
# no single tensor is selected.
print_tensors_in_checkpoint_file(ckpt_path, all_tensors=True,
                                 all_tensor_names=True, tensor_name='')


tensor_name:  beta1_power
0.3138105
tensor_name:  beta2_power
0.989055
tensor_name:  l1/bias
[ 0.02627242  0.05014378 -0.04511173  0.        ...
tensor_name:  l1/bias/Adam
[-3.4903574e+00 -2.6453564e+01  6.5189457e+00 ...
tensor_name:  l1/bias/Adam_1
[3.0681181e-01 1.6404903e+01 1.7573323e+00 ...
tensor_name:  l1/kernel
[[-0.05044619  0.07713371 -0.3241992   0.02499923  0.04578751 -0.2615546
   0.16445747 -0.34566906  0.2644527  -0.24029821 -0.1680427  -0.31824082
   0.2339457  -0.13366055  0.30801657  0.3266713  -0.3441043  -0.16687115
   0.06713113  0.22774872] ...

方法 2：输出指定层的指定参数

其中，layer_names 中的 'l1' 就是用 tf.layers.dense() 创建网络中某一层时通过 name 参数指定的层名，例如：l1 = tf.layers.dense(tf_x, 20, tf.nn.relu, name='l1')。

这里的 kernel 指权重（weight）参数，bias 指偏置参数。另外，方法 1 输出中以 /Adam、/Adam_1 结尾的变量以及 beta1_power、beta2_power 是 Adam 优化器保存的状态变量，并不是网络层本身的参数。

参考链接:
https://www.codelast.com/%e5%8e%9f%e5%88%9b-%e5%a6%82%e4%bd%95%e5%8f%96%e5%87%ba-tf-layers-dense-%e5%ae%9a%e4%b9%89%e7%9a%84%e5%85%a8%e8%bf%9e%e6%8e%a5%e5%b1%82%e7%9a%84weight%e5%92%8cbias%e5%8f%82%e6%95%b0%e5%80%bc/


# Method 2: read the kernel and bias of selected dense layers from the latest
# checkpoint, print them, and dump them (with shapes) to a text file.
model_dir = "tfPara/"

# get_checkpoint_state returns None when model_dir contains no checkpoint
# state file; fail early with a readable message instead of an AttributeError.
ckpt = tf.train.get_checkpoint_state(model_dir)
if ckpt is None or not ckpt.model_checkpoint_path:
    raise OSError("No checkpoint found in %r" % model_dir)
ckpt_path = ckpt.model_checkpoint_path
reader = tf.train.NewCheckpointReader(ckpt_path)

file = 'tfPara/para.txt'
# Layer names as passed via `name=` to tf.layers.dense(); each layer's
# variables are looked up as "<name>/kernel" and "<name>/bias".
layer_names = ['l1', 'l2', 'l3']
with open(file, 'w') as f:
    for ln in layer_names:
        weights = reader.get_tensor(ln + '/kernel')  # weight matrix
        bias = reader.get_tensor(ln + '/bias')
        print(weights)
        print(bias)
        # One header line (name + shape) followed by the values as a nested
        # Python list, for each tensor. The header says "weights" even though
        # the checkpoint tensor is named "kernel".
        f.writelines([
            ln + '/weights;   shape:' + str(weights.shape) + '\n',
            str(weights.tolist()) + '\n',
            ln + '/bias;   shape:' + str(bias.shape) + '\n',
            str(bias.tolist()) + '\n',
        ])
        
[[-0.05044619  0.07713371 -0.3241992   0.02499923  0.04578751 -0.2615546
   0.16445747 -0.34566906  0.2644527  -0.24029821 -0.1680427  -0.31824082
   0.2339457  -0.13366055  0.30801657  0.3266713  -0.3441043  -0.16687115
   0.06713113  0.22774872] ...

[[ 0.22894777]
 [ 0.49630672] ...
发布了36 篇原创文章 · 获赞 0 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/weixin_38102912/article/details/101753418