[TensorFlow] A utility class for processing parameter names and matrix values in a checkpoint

0x00 Preface

At present, TensorFlow model parameter files are not as convenient to work with as PyTorch's parameter files.
A recent task required copying the values of a specific row of a certain parameter matrix to some other rows.
This is easy in PyTorch, because a checkpoint there is just a basic Python data structure: a bunch of tensors wrapped in an OrderedDict.
The same operation is more troublesome in TensorFlow, so I implemented the utility class CheckpointMonitor to make this kind of processing more efficient.
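
For comparison, here is a minimal sketch of why the same operation is trivial on the PyTorch side (the file name model.pt and the key some.weight below are purely illustrative):

import torch

# A PyTorch checkpoint is just nested Python containers (often an OrderedDict of tensors),
# so copying one row of a matrix into other rows is plain tensor indexing.
state_dict = torch.load('model.pt', map_location='cpu')  # illustrative file name
weight = state_dict['some.weight']                        # illustrative key
weight[1] = weight[0]                                     # copy row 0 into row 1
torch.save(state_dict, 'model_modified.pt')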

0x01 Effect and API

  • Supports modifying any parameter matrix in a TensorFlow checkpoint (ckpt) file
    • Parameter names can be renamed in batches or individually, keeping the attributes of the parameters unchanged
      • Batch renaming works by passing in a function: each input parameter name is mapped to the name returned by the custom function (see the sketch after this list)
      • This step is needed, for example, when converting parameters between TensorFlow and PyTorch
    • The modified parameters can be saved back as a TensorFlow checkpoint (Figure 1 below) or as a PyTorch file (Figure 2 below)
    • You can filter, inspect, and modify all or part of the values of any parameter matrix; the whole process works on numpy arrays
    • The parameter order in the model file is maintained automatically, and existing parameters can be extended, e.g. by concatenation (parameter splicing)
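
For instance, a minimal sketch of batch renaming with a custom function, assuming the CheckpointMonitor class defined in section 0x04 and a placeholder checkpoint path (the slash-to-dot rule below is only illustrative):

cm = CheckpointMonitor('./model.ckpt-250042')  # placeholder path

def tf_to_torch_name(name):
    # illustrative rule: turn slash-separated TF scopes into dot-separated names
    return name.replace('/', '.')

cm.modify_var_names(rename_func=tf_to_torch_name)
cm.save_model('./converted', model_name='converted_model', method='torch')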

0x02 API list

  • __init__(checkpoint_path): the initialization argument is the checkpoint path
  • list_variables(): show all parameters in the current checkpoint as (name, shape) pairs
  • list_target_variables(pattern): like list_variables, but returns only the parameters whose names match the filter pattern (Figure 3)
  • get_var_data(var_name): get the parameter with the given name from the model file, as a numpy array
  • save_model(path, method='tf'): save the model back as a TensorFlow checkpoint or a PyTorch file
  • modify_var_name(old_name, new_name): rename a single parameter
  • modify_var_names(rename_func): rename parameters in batches
  • modify_var_data(var_name, var_data): modify the value of a parameter (see the usage sketch after this list)
  • These are the current APIs; more may be added later (for example, encryption/decryption or model-compression tools could be integrated into this class)
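
As a usage sketch, the row-copy requirement from the preface might look like this with the API above (the variable name 'embeddings/word_embeddings' and the row indices are illustrative assumptions):

cm = CheckpointMonitor('/data/sharedata/model_files/model.ckpt-250042')

cm.list_target_variables('embedding')                   # inspect candidate parameters
data = cm.get_var_data('embeddings/word_embeddings')    # numpy array, assumed name
data[[1, 2], :] = data[0, :]                            # copy row 0 into rows 1 and 2
cm.modify_var_data('embeddings/word_embeddings', data)

cm.save_model('./modified', model_name='my_modified_model', method='tf')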

0x03 Requirements

  • python >= 3.6 (lower versions not tested)
  • tensorflow >= 1.15 (lower versions not tested)
  • torch >= 1.4 (only required if you need to save as a PyTorch model)
  • numpy

0x04 Source Code

import os
# force CPU mode: hide GPUs from TensorFlow / CUDA
os.environ['CUDA_LAUNCH_BLOCKING'] = ""
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
from collections import OrderedDict


class CheckpointMonitor(object):
    """
    # CPU mode
    import os
    os.environ['CUDA_LAUNCH_BLOCKING'] = ""
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    """
    def __init__(self, checkpoint_path=None):
        if checkpoint_path is None:  # default path for testing
            checkpoint_path = '/data/sharedata/model_files/model.ckpt-250042'
        
        self.saver = None
        self.graph = None
        self.dump_path = './'
        self.checkpoint_path = checkpoint_path
        self.default_dump_name = 'my_modified_model'
        self.var_name_list = []
        self.var_shape_dict = OrderedDict()
        self.var_data_dict = OrderedDict()
        self.init_vars()
    
    def reload(self, checkpoint_path=None):
        self.__init__(checkpoint_path=checkpoint_path)
    
    def init_vars(self, checkpoint_path=None):
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        self.var_shape_dict = OrderedDict(
            self.list_variables(checkpoint_path))
        self.var_name_list = list(self.var_shape_dict.keys())
        for var_name in self.var_name_list:
            # print(var_name)
            var_data = self.get_var_data(var_name, checkpoint_path)
            # dict(str, np.array)
            self.var_data_dict.update({var_name: var_data})
    
    def sort_var_dicts(self):
        self.var_data_dict = OrderedDict(
            [(var_name, self.var_data_dict[var_name]) 
             for var_name in self.var_name_list])
        self.var_shape_dict = OrderedDict(
            [(var_name, self.var_shape_dict[var_name]) 
             for var_name in self.var_name_list])
    
    def list_variables(self, checkpoint_path=None):
        # get all variables in form of tuple(name, shape) in checkpoint
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        # return a list of (var_name, shape)
        return tf.contrib.framework.list_variables(checkpoint_path)
    
    def list_target_variables(self, pattern, checkpoint_path=None):
        if checkpoint_path is None:
            if len(self.var_shape_dict) != 0:
                # lazy loading
                var_list = self.var_shape_dict.items()
                return [(name, shape) for (name, shape) 
                        in var_list if pattern in name]
            else:  # load for cold-booting
                checkpoint_path = self.checkpoint_path
        var_list = self.list_variables(checkpoint_path)
        return [(name, shape) for (name, shape) in var_list if pattern in name]
    
    def get_var_data(self, var_name, checkpoint_path=None):
        # load variable from target checkpoint with the name as var_name
        if checkpoint_path is None:
            if len(self.var_data_dict) != 0:
                # lazy loading
                return self.var_data_dict.get(var_name)
            checkpoint_path = self.checkpoint_path
        # return the variable object (np.array)
        return tf.contrib.framework.load_variable(checkpoint_path, var_name)
    
    @staticmethod
    def generate_rename_func(old_name_list, new_name_list):
        def fn(var_name):
            if var_name in old_name_list:
                return new_name_list[old_name_list.index(var_name)]
            return var_name
        return fn
    
    def modify_var_name(self, old_name, new_name, inplace=True):
        var_index = self.var_name_list.index(old_name)
        self.var_name_list[var_index] = new_name
        self.var_data_dict[new_name] = self.var_data_dict[old_name]
        self.var_shape_dict[new_name] = self.var_shape_dict[old_name]
        del self.var_data_dict[old_name]
        del self.var_shape_dict[old_name]
        if inplace:
            self.sort_var_dicts()
    
    def modify_var_names(self, rename_func=None):
        # modify var_names in batch, with a feed function `rename_func`
        if rename_func is None:
            rename_func = lambda _name: _name

        # iterate over a copy, since modify_var_name updates var_name_list in place
        for var_name in list(self.var_name_list):
            new_name = rename_func(var_name)
            if new_name != var_name:
                self.modify_var_name(var_name, new_name, inplace=False)
                print('Re-naming {} to {}.'.format(var_name, new_name))
        self.sort_var_dicts()
    
    def modify_var_data(self, var_name, var_data):
        assert isinstance(var_data, np.ndarray)
        if var_name not in self.var_name_list:
            print("Invalid variable name: {}".format(var_name))
            print("You can get available variable names by calling list_variables()")
            return
        self.var_shape_dict[var_name] = list(var_data.shape)
        self.var_data_dict[var_name] = var_data
    
    def generate_var_dict_for_torch(self, var_list=None):
        import torch  # local import: torch is only required for the PyTorch export path
        if var_list is None:
            var_list = self.var_data_dict.items()
        torch_model_dict = OrderedDict()
        for var_name, var_data in var_list:
            var = torch.tensor(var_data)
            torch_model_dict.update({var_name: var})
        return torch_model_dict
    
    def generate_var_list_for_saver(self, var_list=None):
        if var_list is None:
            var_list = self.var_data_dict.items()
        saver_var_list = []
        with tf.Session() as sess:
            for var_name, var_data in var_list:
                var = tf.Variable(var_data, name=var_name)
                saver_var_list.append(var)
        return saver_var_list
    
    def save_model(self, new_checkpoint_path=None, model_name=None, method='pt'):
        if new_checkpoint_path is None:
            new_checkpoint_path = self.dump_path
        if not os.path.exists(new_checkpoint_path):
            os.makedirs(new_checkpoint_path)
        if model_name is None:
            model_name = self.default_dump_name
        checkpoint_path = os.path.join(
            new_checkpoint_path, model_name)
        
        method_dict = {
            'pt': self.save_model_as_pt,
            'tf': self.save_model_as_tf,
            'ckpt': self.save_model_as_tf,
            'torch': self.save_model_as_pt,
            'pytorch': self.save_model_as_pt,
            'tensorflow': self.save_model_as_tf,
        }
        method_dict[method](checkpoint_path)
    
    def save_model_as_pt(self, checkpoint_path):
        import torch
        var_dict = self.generate_var_dict_for_torch()
        checkpoint = OrderedDict({'model': var_dict})
        torch.save(checkpoint, checkpoint_path + '.pt')
        print("Checkpoint saving finished !\n{}".format(
            checkpoint_path + '.pt'))
    
    def save_model_as_tf(self, checkpoint_path):
        with tf.Session() as sess:
            var_list = self.generate_var_list_for_saver()
            # Construct the Saver
            self.saver = tf.train.Saver(var_list=var_list)
            # Necessary! Call the initializer at the beginning.
            sess.run(tf.global_variables_initializer())
            self.saver.save(sess, checkpoint_path)
            print("Checkpoint saving finished !\n{}".format(
                checkpoint_path))

0x05 Effect Display

Figure 1: Read the original TF model → modify a single value → save it back → read the new TF model → check the modification
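
Since the original screenshot is not reproduced here, the Figure 1 check could be reproduced roughly like this, continuing from the usage sketch in section 0x02 (paths and the variable name are again illustrative):

# read the new TF model and check the modification
cm2 = CheckpointMonitor('./modified/my_modified_model')
print(cm2.get_var_data('embeddings/word_embeddings')[1])  # expect the copied row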

Figure 2: Read the original TF model → modify a single value → save it as a PyTorch model → read the new PyTorch model → check the modification
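
A corresponding sketch of the Figure 2 path, again continuing from the same usage sketch; note that save_model_as_pt wraps the parameters under the 'model' key:

import torch

cm.save_model('./modified', model_name='my_modified_model', method='torch')
checkpoint = torch.load('./modified/my_modified_model.pt')
print(checkpoint['model']['embeddings/word_embeddings'][1])  # expect the copied row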

Origin: blog.csdn.net/okcd00/article/details/107127281