Pytorch基础知识

1. 保存和加载整个模型

(1) 保存和加载整个模型

torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')

(2) 仅保存和加载模型参数(推荐使用)

torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

猜你喜欢

转载自blog.csdn.net/xuluhui123/article/details/80140964
今日推荐