1. Save the entire neural network, including both its structure and its parameters; the object passed to torch.save is the network net itself.
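import torch
# Save and load the entire model (structure + parameters)
torch.save(model_object, 'resnet.pth')
# weights_only=False is required on PyTorch >= 2.6 to unpickle a full model; use it only for trusted files
model = torch.load('resnet.pth', weights_only=False)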
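A self-contained sketch of the first method is below. `TinyNet` is a hypothetical stand-in for a real model such as a ResNet, the temporary file path is arbitrary, and the `weights_only` keyword assumes PyTorch 1.13 or newer. It saves the whole module, loads it back, and checks that the reloaded copy gives the same outputs as the original.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical tiny network standing in for a real model such as a ResNet.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def save_and_load_whole_model(model, path):
    """Pickle the entire module (structure + parameters), then load it back."""
    torch.save(model, path)
    # A full-module pickle is not just tensors, so PyTorch >= 2.6 needs
    # weights_only=False to load it; only do this for files you trust.
    return torch.load(path, weights_only=False)


torch.manual_seed(0)
model = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tiny_whole.pth")
restored = save_and_load_whole_model(model, path)

# The restored object is a new TinyNet instance with identical parameters.
x = torch.randn(3, 4)
print(type(restored).__name__)
print(torch.equal(model(x), restored(x)))
```

Because the class is stored by reference, this only works when `TinyNet` can be imported at load time from the same place it was defined when saving.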
2. Save only the trained parameters of the neural network; the object passed to torch.save is net.state_dict().
# Save the parameters of the my_resnet model to my_resnet.pth
torch.save(my_resnet.state_dict(), "my_resnet.pth")
# Load the parameters stored in my_resnet.pth back into my_resnet
my_resnet.load_state_dict(torch.load("my_resnet.pth"))
Here, my_resnet must already be an instance of the same network architecture that was used to create my_resnet.pth.
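The second method can be sketched the same way. `TinyNet` is again a hypothetical stand-in for the `my_resnet` architecture and the file path is arbitrary; the key point is that a fresh instance of the architecture must be constructed before `load_state_dict` can fill in its parameters.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical tiny network standing in for the my_resnet architecture.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


torch.manual_seed(0)
trained = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tiny_state.pth")

# Save only the parameters and buffers (an ordered mapping of name -> tensor).
torch.save(trained.state_dict(), path)

# Loading requires a newly constructed instance of the same architecture.
fresh = TinyNet()
result = fresh.load_state_dict(torch.load(path, map_location="cpu"))
fresh.eval()  # inference mode; matters for layers such as dropout or batch norm

print(sorted(trained.state_dict().keys()))
print(all(torch.equal(a, b) for a, b in
          zip(trained.state_dict().values(), fresh.state_dict().values())))
```

`load_state_dict` is strict by default and returns the lists of missing and unexpected keys, which gives a quick check that the architecture matches the saved file.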
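The first method is relatively more time-consuming. It also stores the model class only by reference, so loading the file requires the original class definition to be importable from the same module path.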