Pytorch常用函数

一、模型的保存与加载
实现训练过程中模型的保存,以及在预训练的基础上继续训练模型
①保存和加载整个模型

# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')

②只保存模型中的参数

# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

猜你喜欢

转载自blog.csdn.net/qq_37053885/article/details/81782819