[Pytorch] neural network model saving and loading

[Pytorch] neural network model saving and loading

There are two ways to save and load the neural network model. One method saves and loads the entire network (including the graph structure of the network), and the other method only saves and loads the parameter names and values ​​of the network (excluding the graph structure of the network).

method 1

Save and load the entire network, including the graph structure of the network. code show as below:

# 保存网络
torch.save(model, PATH) # model是想要保存的模型,PATH是保存下来的文件的路径
# 加载网络
model2 = torch.load(PATH)

Method 2

Save and load the parameter names and values ​​of the network, excluding the graph structure of the network. code show as below:

# 保存网络
torch.save(model.state_dict(), PATH) # model.state_dict()是想要保存的模型的参数名及值,PATH是保存下来的文件的路径
# 加载网络
model2 = MyModel() # 初始化准备载入参数的新模型
state_dict = torch.load(PATH) # 同样是torch.load(PATH),在方法1中返回的是整个模型,在方法2中返回的是模型的参数及值。区别就在于当初PATH这个路径的文件里,保存的是整个模型还是只有模型的参数及值
model2.load_state_dict(state_dict)

reference:

The source code explains Pytorch's state_dict and load_state_dict in detail (only one method is mentioned, first save and then load)

Pytorch saves and loads the model (load and load_state_dict) (talks about 2 methods)

Guess you like

Origin blog.csdn.net/Mocode/article/details/131270058