pytorch:保存和加载模型

之前跑了好几个模型,都需要保存和加载模型信息。为了省得每次都上网找,开一个帖记录一下。

1. 保存模型

此处只保存模型参数

model = torch.nn.Linear(1, 2)
torch.save(model.state_dict(), "./model.pth")

2. 加载模型

model = torch.nn.Linear(1, 2)
model.load_state_dict(torch.load("./model.pth"))

猜你喜欢

转载自blog.csdn.net/weixin_43466027/article/details/121363300