加载训练的模型参数并继续训练

参考连接:

https://blog.csdn.net/hungryof/article/details/81364487

保存模型:

torch.save(model.state_dict(), model_path)

加载模型时一般用

 model.load_state_dict(torch.load(model_path))

其中,model_path 为模型路径。

值得注意的是:torch.load 返回的是一个 OrderedDict.

但是可能这样加载模型继续训练时,会出现一些问题,故可以改为:

model.load_state_dict(torch.load(model_path), strict=False)

pytorch官网:

https://pytorch.org/tutorials/beginner/saving_loading_models.html

感觉写的蛮详细的博客:

https://www.jianshu.com/p/1cd6333128a1

猜你喜欢

转载自www.cnblogs.com/Bella2017/p/11908822.html