pytorch Load部分weights

版权声明:本文为博主原创文章,转载请注明出处。 https://blog.csdn.net/weixin_39610043/article/details/90032911

我们从网上down下来的模型与我们的模型可能就存在一个层的差异,此时我们就需要重新训练所有的参数是不合理的。
因此我们可以加载相同的参数,而忽略不同的参数,代码如下:

        pretrained_dict = torch.load(“model.pth”)
        model_dict = et.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)

猜你喜欢

转载自blog.csdn.net/weixin_39610043/article/details/90032911