Tensor变量的存取(包括parameter)
对于普通Tensor变量的存取,如下代码所示:
import torch
import torch.nn as nn
x = torch.ones(3)
torch.save(x,'x.pt')
x2 = torch.load('x.pt')
print(x2)
读写模型参数
保存模型参数
torch.save(net.state_dict(),'model_param.pth')
载入模型参数
mynet = MLP()
mynet.load_state_dict(torch.load('model_param.pth'))
mynet.state_dict()
保存和读取整个模型
模型的保存
print(net(x))
torch.save(net,'model.pth')
模型载入
mynets = torch.load('model.pth')
mynets(x)