First, some notes based on this blog: https://www.jianshu.com/p/4905bf8e06e5
- Method 1: serialize only the model's parameters (its state_dict). The file extension can be .mdl, .pt, etc.
torch.save(model.state_dict(), MODEL_PATH)
Then deserialize and reload:
model.load_state_dict(torch.load(MODEL_PATH))
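A minimal end-to-end sketch of Method 1, using a hypothetical toy `nn.Linear` model and a temporary file path for illustration. Note that loading a state_dict requires first constructing a model with the same architecture:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical toy model used only for illustration.
model = nn.Linear(4, 2)

path = os.path.join(tempfile.mkdtemp(), "model.pt")
# Save only the parameters, not the model object itself.
torch.save(model.state_dict(), path)

# To load, first build a model with the same architecture,
# then copy the saved parameters into it.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path))
restored.eval()  # switch to eval mode before inference
```

Saving only the state_dict is the commonly recommended approach, since it does not tie the checkpoint to a particular class definition or file layout.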
- Method 2: save the entire model object; the file extension can be .pth.tar.
torch.save(model, PATH)
Then load:
model = torch.load(PATH)
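A sketch of Method 2 under the same illustrative setup (toy model, temporary path). Because this approach pickles the whole model object, the class definition must be importable at load time, and on recent PyTorch versions `weights_only=False` must be passed explicitly to allow unpickling a full module:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical toy model used only for illustration.
model = nn.Sequential(nn.Linear(4, 2))

path = os.path.join(tempfile.mkdtemp(), "model.pth.tar")
# Save the entire model object (architecture + parameters) via pickle.
torch.save(model, path)

# Load the whole model back; no need to reconstruct the architecture first,
# but the class definitions must be available in the loading environment.
restored = torch.load(path, weights_only=False)
restored.eval()
```

The convenience comes at a cost: the pickle is brittle across refactors (renamed modules, moved classes), which is why Method 1 is usually preferred for long-lived checkpoints.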
There are also details on moving models between GPU and CPU when loading; see the blog above or the official PyTorch documentation.
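One common case of the GPU/CPU transfer issue is loading a checkpoint that was saved on a GPU onto a CPU-only machine; `torch.load` accepts a `map_location` argument for this. A minimal sketch, again with a hypothetical toy model:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical toy model; in practice this checkpoint might have been
# saved from a model living on a CUDA device.
model = nn.Linear(3, 1)
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(model.state_dict(), path)

# map_location remaps all storages in the checkpoint onto the CPU,
# so loading works even when no GPU is present.
state = torch.load(path, map_location=torch.device("cpu"))
cpu_model = nn.Linear(3, 1)
cpu_model.load_state_dict(state)
```

Conversely, after loading on a machine with a GPU, `model.to("cuda")` moves the parameters onto the device.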