Pytorch中参数和模型的保存与读取

Tensor变量的存取(包括parameter)

对于普通Tensor变量的存取,如下代码所示:

import torch
import torch.nn as nn
x = torch.ones(3)
torch.save(x,'x.pt')
x2 = torch.load('x.pt')
print(x2)

读写模型参数

保存模型参数

torch.save(net.state_dict(),'model_param.pth')

载入模型参数

mynet = MLP()
mynet.load_state_dict(torch.load('model_param.pth'))
mynet.state_dict()

保存和读取整个模型

模型的保存

print(net(x))
torch.save(net,'model.pth')

模型载入

mynets = torch.load('model.pth')
mynets(x)

猜你喜欢

转载自blog.csdn.net/qq_25105061/article/details/116268683