Two ways to save and load a model in PyTorch

1. Save the entire neural network, including both its structure and its parameters. The object passed to torch.save is the network itself (net).

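# Save and load the entire model
torch.save(model_object, 'resnet.pth')
# PyTorch 2.6 and later default to weights_only=True, so pass weights_only=False to unpickle a full model
model = torch.load('resnet.pth', weights_only=False)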
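
As a rough, self-contained illustration of the first method, the sketch below saves a small network and loads it back. The nn.Sequential model and the temporary file path are stand-ins invented for this example, not the resnet from the post, and the weights_only argument assumes PyTorch 1.13 or later.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the network; any nn.Module is saved the same way.
model_object = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Temporary location chosen only for this example.
path = os.path.join(tempfile.mkdtemp(), "whole_model.pth")

# Pickles the module object, so structure and parameters travel together.
torch.save(model_object, path)

# weights_only=False allows unpickling a full module on recent PyTorch releases.
# Only load files you trust this way, since unpickling can run arbitrary code.
restored = torch.load(path, weights_only=False)
restored.eval()

# The restored model should reproduce the original outputs exactly.
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model_object(x), restored(x))
```

Because the file stores a pickled object, the class definition of the model still has to be importable wherever the file is loaded.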

2. Save only the trained parameters of the neural network. The object passed to torch.save is net.state_dict().

# Save the parameters of my_resnet to my_resnet.pth
torch.save(my_resnet.state_dict(), "my_resnet.pth")
# Load the parameters stored in my_resnet.pth into my_resnet
my_resnet.load_state_dict(torch.load("my_resnet.pth"))

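Here, my_resnet must be an instance of the same network structure whose parameters were saved to my_resnet.pth.
The first method is relatively more time-consuming, since it serializes the whole model object rather than only its parameters.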
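
The second method can be sketched end to end as follows. TinyNet is a hypothetical toy network used only for this example, and the file path is a temporary one. The point is that the architecture is rebuilt in code first and the saved parameters are then copied into it.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy network standing in for my_resnet."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


trained = TinyNet()  # pretend this instance has been trained

# Temporary location chosen only for this example.
path = os.path.join(tempfile.mkdtemp(), "tiny_net.pth")

# Save only the parameters (a dictionary of tensors), not the module object.
torch.save(trained.state_dict(), path)

# Rebuild the same architecture, then load the saved parameters into it.
fresh = TinyNet()
state = torch.load(path, map_location="cpu")
result = fresh.load_state_dict(state)
fresh.eval()

# Every parameter tensor should now match the saved model.
params_match = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), fresh.state_dict().values())
)
```

If the rebuilt network does not match the saved parameters (for example, a layer was renamed), load_state_dict raises an error by default, which is why the structure has to correspond to the file.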

Origin blog.csdn.net/qq_41872271/article/details/105367290