重构预训练模型后加载参数

# 加载相同state_dict
def load_same_state_dict(myself_model: nn.Module, pretrain_model: nn.Module):
    pretrain_dict = pretrain_model.state_dict()
    myself_dict = myself_model.state_dict()

    # 当模型中的某层是同时在两个模型中共有时才取出
    pretrain_dict = {
    
    k: v for k, v in pretrain_dict.items() if k in myself_dict}
    
    
    myself_dict.update(pretrain_dict)
    # 然后将参数导入你的模型即可
    myself_model.load_state_dict(myself_dict)
    return myself_model

猜你喜欢

转载自blog.csdn.net/q506610466/article/details/123778020