Modifying the contents of a TensorFlow checkpoint: printing the variable names and values in a checkpoint

https://blog.csdn.net/qq_33666011/article/details/80522564

I came across two articles I liked and wanted to keep them, haha. Reposted from:

https://blog.csdn.net/qq_32799915/article/details/80312928

https://zhuanlan.zhihu.com/p/36982683

TensorFlow: printing the variable names and values in a checkpoint

from tensorflow.python import pywrap_tensorflow

model_dir = "/xxxxxxxxx/model.ckpt"  # location of the checkpoint file
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(model_dir)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and values
for key in var_to_shape_map:
    print("tensor_name: ", key)    # print the variable name
    print(reader.get_tensor(key))  # print the variable value

Output:
Only the variable names are shown here.

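How to rename variables in a trained TensorFlow model

Have you run into any of the following situations, where you really need to rename the variables of a trained TensorFlow model?

  1. You need to restore weights from a pretrained model, but a different framework was used, so some layers have different variable names even though the basic network structures correspond one to one, for example slim versus tensorlayer;
  2. You are converting the model to another framework, for example using a tool to convert a TensorFlow model to a Caffe model, and some variable names differ from the BN layer variable names that the conversion tool defines;
  3. You want to change the length of the variable names; and so on.

So how do we go about it?

First, trained TensorFlow models commonly come in two formats: checkpoint and frozen pb. However, because the script uses tf.contrib.framework.list_variables (tf.contrib exists only in TensorFlow 1.x), only the checkpoint format is supported for now;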


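Second, we only want to change the variable names and do not want to touch the graph structure, so we first restore the graph and then change only the variable names in it;

Finally, here is the code from my experiment and the result of running it!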

#!/usr/bin/env python
# -*- coding:utf-8 -*-

'''
############################################################
rename tensorflow variable.
############################################################
'''

import tensorflow as tf
import argparse
import os
import re


def get_parser():
    parser = argparse.ArgumentParser(description='parameters to rename tensorflow variable!')
    parser.add_argument('--ckpt_path', type=str, help='the ckpt file where to load.')
    parser.add_argument('--save_path', type=str, help='the ckpt file where to save.')
    parser.add_argument('--rename_var_src', type=str, help="""Comma separated list of replace variable from""")
    parser.add_argument('--rename_var_dst', type=str, help="""Comma separated list of replace variable to""")
    parser.add_argument('--add_prefix', type=str, help='prefix of newname.')
    args = parser.parse_args()
    return args


def load_model(model_path, input_map=None):
    # Check if the model is a model directory (containing a metagraph and a checkpoint file)
    # or if it is a protobuf file with a frozen graph
    model_exp = os.path.expanduser(model_path)
    if os.path.isfile(model_exp):
        # A single file means a frozen pb graph, which this script cannot handle.
        raise ValueError('not support: %s' % model_exp)
    else:
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)

        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)

        # Rebuild the graph from the .meta file, then restore the weights into it.
        saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
        saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))

    return saver


def get_model_filenames(model_dir):
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if len(meta_files) == 0:
        raise ValueError('No meta file found in the model directory (%s)' % model_dir)
    elif len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]

    # Prefer the checkpoint recorded in the "checkpoint" state file.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_file = os.path.basename(ckpt.model_checkpoint_path)
        return meta_file, ckpt_file

    # Otherwise fall back to the file with the highest step number.
    max_step = -1
    ckpt_file = None
    for f in files:
        step_str = re.match(r'(^model-[\w\- ]+\.ckpt-(\d+))', f)
        if step_str is not None and len(step_str.groups()) >= 2:
            step = int(step_str.groups()[1])
            if step > max_step:
                max_step = step
                ckpt_file = step_str.groups()[0]
    if ckpt_file is None:
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
    return meta_file, ckpt_file


def rename(args):
    '''rename tensorflow variable, just for checkpoint file format.'''

    replace_from = args.rename_var_src.strip().split(',')
    replace_to = args.rename_var_dst.strip().split(',')

    assert len(replace_from) == len(replace_to)

    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(args.ckpt_path):
            # Load the variable
            var = tf.contrib.framework.load_variable(args.ckpt_path, var_name)

            # Set the new name
            new_name = var_name

            for index in range(len(replace_from)):
                new_name = new_name.replace(replace_from[index], replace_to[index])

            if args.add_prefix:
                new_name = args.add_prefix + new_name

            print('Renaming %s to %s.' % (var_name, new_name))
            # Rename the variable
            var = tf.Variable(var, name=new_name)

        # Save the variables
        saver = load_model(args.ckpt_path)
        sess.run(tf.global_variables_initializer())
        saver.save(sess, args.save_path)


if __name__ == '__main__':
    args = get_parser()
    rename(args)
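The fallback in get_model_filenames, which picks the checkpoint with the highest step number when no checkpoint state file is available, can be tried without TensorFlow. Below is a minimal sketch, assuming file names of the form model-<name>.ckpt-<step>. The function name latest_checkpoint_name and the sample file names are hypothetical and only serve the illustration.

```python
import re

# Same pattern as the script's fallback, with the dot before "ckpt" escaped.
CKPT_PATTERN = re.compile(r'^(model-[\w\- ]+\.ckpt-(\d+))')


def latest_checkpoint_name(files):
    """Return the checkpoint prefix with the highest step number, or None."""
    best_step, best_name = -1, None
    for f in files:
        m = CKPT_PATTERN.match(f)
        if m is None:
            continue  # not a checkpoint data/index file
        step = int(m.group(2))
        if step > best_step:
            best_step, best_name = step, m.group(1)
    return best_name


# Hypothetical directory listing, for illustration only.
files = [
    'model-20180408.ckpt-1000.index',
    'model-20180408.ckpt-1000.data-00000-of-00001',
    'model-20180408.ckpt-25000.index',
    'model-20180408.meta',
    'checkpoint',
]
print(latest_checkpoint_name(files))
```

The steps are compared as integers, so step 25000 ranks above step 1000 even though a plain string comparison would order them the other way.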

Save the full script as rename_tf_variable.py, then run it with a command like this:

python rename_tf_variable.py --ckpt_path ~/github/MobileFaceNet_local/output/ckpt/ --save_path /home/xsr-ai/scripts/ckpt --rename_var_src gamma,moving_mean,moving_variance --rename_var_dst scale,mean,variance

This command renames the batch norm layer variables "gamma,moving_mean,moving_variance" to "scale,mean,variance". The result of running it is as follows:
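As a rough illustration of the effect on individual names (not the script's actual log), here is a minimal pure-Python sketch of the same replace-then-prefix rule that rename() applies. The helper rename_variable and the variable names in the list are hypothetical and not taken from a real checkpoint.

```python
def rename_variable(name, replace_from, replace_to, add_prefix=None):
    """Apply each substring replacement in order, then prepend the prefix."""
    new_name = name
    for src, dst in zip(replace_from, replace_to):
        new_name = new_name.replace(src, dst)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


# Flag values from the command above, split the same way the script does.
replace_from = 'gamma,moving_mean,moving_variance'.strip().split(',')
replace_to = 'scale,mean,variance'.strip().split(',')

# Hypothetical variable names, for illustration only.
names = [
    'conv1/BatchNorm/gamma',
    'conv1/BatchNorm/beta',
    'conv1/BatchNorm/moving_mean',
    'conv1/BatchNorm/moving_variance',
    'conv1/weights',
]
for name in names:
    new_name = rename_variable(name, replace_from, replace_to)
    print('Renaming %s to %s.' % (name, new_name))
```

Because the rule is plain substring replacement, a source string such as "gamma" is replaced wherever it occurs in a name, not only in batch norm layers, so it is worth checking the printed pairs before trusting the saved checkpoint.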







Reposted from blog.csdn.net/Jason_mmt/article/details/83116156