迁移学习/fine-tuning

参考:


Pre-training model


完整程序:迁移学习


  • 迁移学习:
    只更新添加层的参数,冻结原始模型
    初始化参数时 只对添加层的参数初始化

  • 模型微调:
    将原来的模型参数作为初始参数,也进行更新
    初始化参数时 只对添加层的参数初始化


解析

加载模型

参考:facenet

def get_model_filenames(model_dir):
    """Locate the metagraph and checkpoint files inside a model directory.

    Args:
        model_dir: path to a directory containing exactly one ``.meta``
            metagraph file and at least one checkpoint file.

    Returns:
        (meta_file, ckpt_file): basenames of the metagraph file and of the
        checkpoint file with the highest global step.

    Raises:
        ValueError: if there is no ``.meta`` file, more than one ``.meta``
            file, or no checkpoint file in *model_dir*.
    """
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if len(meta_files) == 0:
        raise ValueError('No meta file found in the model directory (%s)' % model_dir)
    elif len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]
    # Prefer the checkpoint state file when TensorFlow can read one.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_file = os.path.basename(ckpt.model_checkpoint_path)
        return meta_file, ckpt_file

    # Fall back to scanning filenames for the checkpoint with the largest step.
    # BUG FIX: the original built this filtered list but then looped over the
    # unfiltered `files`; iterate the filtered list instead (equivalent result,
    # since the regex below requires '.ckpt-' anyway, but no dead code).
    ckpt_candidates = [s for s in files if '.ckpt' in s]
    pattern = re.compile(r'(^model-[\w\- ]+.ckpt-(\d+))')  # hoisted out of the loop
    max_step = -1
    ckpt_file = None
    for f in ckpt_candidates:
        step_str = pattern.match(f)
        if step_str is not None and len(step_str.groups()) >= 2:
            step = int(step_str.groups()[1])
            if step > max_step:
                max_step = step
                ckpt_file = step_str.groups()[0]
    if ckpt_file is None:
        # BUG FIX: the original raised UnboundLocalError here when nothing
        # matched; fail with an explicit, descriptive error instead.
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
    return meta_file, ckpt_file
def load_model(model, input_map=None):
    """Load a model into the default graph.

    *model* is either a single protobuf file holding a frozen graph, or a
    model directory holding a metagraph (``.meta``) plus a checkpoint.

    Args:
        model: path to the frozen-graph file or the model directory.
        input_map: optional mapping of graph input names to tensors,
            forwarded to the TensorFlow import call.
    """
    path = os.path.expanduser(model)
    if not os.path.isfile(path):
        # Directory case: import the metagraph, then restore the checkpoint
        # weights into the current default session.
        print('Model directory: %s' % path)
        meta_file, ckpt_file = get_model_filenames(path)

        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)

        saver = tf.train.import_meta_graph(os.path.join(path, meta_file), input_map=input_map)
        saver.restore(tf.get_default_session(), os.path.join(path, ckpt_file))
        return

    # File case: parse the serialized GraphDef and import it as-is.
    print('Model filename: %s' % path)
    with gfile.FastGFile(path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, input_map=input_map, name='')

得到输入输出tensor

# Get input and output tensors from the (already loaded) default graph.
graph = tf.get_default_graph()
images_placeholder = graph.get_tensor_by_name("input:0")  # input, shape [batch_size, 149, 149, 3]
embeddings = graph.get_tensor_by_name("embeddings:0")  # output, shape [batch_size, 128]
phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")  # is_training flag

函数使用

参考:TensorFlow常用的函数

tf.stop_gradient() # Block gradient flow: the wrapped tensor is not updated by backprop.

tvars = tf.trainable_variables() # Get every trainable (updatable) variable.
d_params = [v for v in tvars if v.name.startswith('D/')]

tf.global_variables_initializer() # Initialize ALL variables (the default).
tf.variables_initializer(var_list=[]) # Initialize only the variables in var_list.

tf.global_variables() # Get all global variables.

ckpt模型做迁移

# Transfer learning from a ckpt model: restore every variable EXCEPT the
# final layers, which are re-initialized and retrained.
saver = tf.train.Saver(max_to_keep=1)  # keep at most one checkpoint version
var_list = tf.global_variables()
print(var_list)
# Skip the last two layers ('fc1' and 'output') so they are retrained.
# BUG FIX: the original matched the typo 'ouput', which never matched a real
# variable name — the output layer was silently restored instead of skipped.
var_list_1 = [var for var in var_list
              if 'fc1' not in var.name and 'output' not in var.name]
print(var_list_1)

saver = tf.train.Saver(max_to_keep=1, var_list=var_list_1)

saver.restore(sess, 'save_path')
发布了96 篇原创文章 · 获赞 179 · 访问量 64万+

猜你喜欢

转载自blog.csdn.net/wc781708249/article/details/80051463