How to save and load a network model trained with PyTorch

Problem Description

Training a deep neural network is expensive, so a trained PyTorch model should be saved so that it can be loaded directly the next time it is needed. How should the model be saved?

Solution

PyTorch provides two main approaches: saving only the model parameters, and saving the entire model.

Method 1: Save only the model parameters

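import torch

# Save
torch.save(the_model.state_dict(), PATH)
# Load: rebuild the model with the same class and arguments, then restore the parameters
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))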
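To make the round trip concrete, here is a minimal sketch of Method 1 from start to finish. TinyNet and the temporary file path are hypothetical stand-ins for TheModelClass and PATH and do not come from the original post.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass."""

    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


# Hypothetical path; any writable location works.
PATH = os.path.join(tempfile.mkdtemp(), "tiny_net_params.pt")

# Save only the parameters of the (here untrained) model.
trained = TinyNet()
torch.save(trained.state_dict(), PATH)

# Load: build a fresh instance with the same arguments, then restore the parameters.
restored = TinyNet()
restored.load_state_dict(torch.load(PATH))
restored.eval()  # switch to inference mode before using the model for prediction

# Both models should now produce identical outputs for the same input.
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))
```

Because only the parameters are stored, the code that defines the model class must be available wherever the file is loaded.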

Method 2: Save the entire model

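# Save
torch.save(the_model, PATH)
# Load: the class definition of the model must be importable here.
# On PyTorch 2.6 and later, pass weights_only=False, because torch.load
# defaults to weights_only=True and refuses to unpickle a whole model.
the_model = torch.load(PATH)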
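A minimal sketch of Method 2 follows. It assumes PyTorch 1.13 or newer, where torch.load accepts the weights_only argument, and it uses a small nn.Sequential model and a temporary path as hypothetical placeholders for the_model and PATH.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical model built only from standard torch.nn layers.
the_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Hypothetical path; any writable location works.
PATH = os.path.join(tempfile.mkdtemp(), "whole_model.pt")

# Save the entire model object (architecture and parameters are pickled together).
torch.save(the_model, PATH)

# Load the entire model. weights_only=False allows full unpickling, so only
# use it on files that come from a trusted source.
loaded_model = torch.load(PATH, weights_only=False)
loaded_model.eval()

# The loaded model should reproduce the outputs of the original one.
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(the_model(x), loaded_model(x))
```

This approach is shorter to write, but the saved file is tied to the exact class definitions and module layout that existed when it was saved, which is why saving only the parameters is usually the more portable choice.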

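Note on file names: this post suggests a .tar extension, for example PATH = './model_file_name/the_model_name.tar'. The PyTorch documentation conventionally uses .pt or .pth for a saved model and .tar for checkpoints that bundle several objects; the extension itself does not change what torch.save writes.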
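When PATH points into a subdirectory, as in the example path above, that directory has to exist before saving, since torch.save does not create it. The sketch below creates the directory and stores a checkpoint dictionary in a .tar file; the dictionary keys, the epoch value, the optimizer settings, and the temporary base directory are illustrative assumptions.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical model and optimizer.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Mirror the example path under a temporary base directory.
base_dir = tempfile.mkdtemp()
PATH = os.path.join(base_dir, "model_file_name", "the_model_name.tar")
os.makedirs(os.path.dirname(PATH), exist_ok=True)  # create the folder first

# Bundle several objects into one checkpoint file.
checkpoint = {
    "epoch": 5,  # illustrative value
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, PATH)

# Restore everything needed to resume training.
loaded = torch.load(PATH)
resumed_model = nn.Linear(4, 2)
resumed_model.load_state_dict(loaded["model_state_dict"])
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.01)
resumed_optimizer.load_state_dict(loaded["optimizer_state_dict"])
start_epoch = loaded["epoch"] + 1
```

Saving the optimizer state alongside the parameters is only needed when training will continue; for inference alone, the parameters are enough.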

 

You can also refer to this article:

https://blog.csdn.net/weixin_41680653/article/details/93768559 

Origin blog.csdn.net/weixin_43450646/article/details/106931575