PyTorch比较好的代码写法

1 修改模型网络结构后,如何载入旧有模型预训练的参数到修改后的模型

def load_checkpoint(model, checkpoint):
        # 修改后的模型的参数
        model_dict = model.state_dict()
        # 旧有模型结构的预训练网络模型,其中['state_dict']保存的模型参数
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        # 将训练好的参数update到model_dict当中
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)

        model.load_state_dict(model_dict)

2 过滤掉冻结不训练参数的code

optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), 
        lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5
        )

疑问:不过滤有影响吗?我觉得没有哎,反正也不更新,待验证...

猜你喜欢

转载自blog.csdn.net/Eric_Evil/article/details/106626354
今日推荐