PyTorch: saving and loading models

A note first: this post is based on the following blog: https://www.jianshu.com/p/4905bf8e06e5

  • Method 1: save only the parameters (state_dict)

First, serialize the parameters. The file extension can be .mdl, .pt, etc.

torch.save(model.state_dict(), MODEL_PATH)

Then deserialize them and load them back into the model:

model.load_state_dict(torch.load(MODEL_PATH))
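A minimal, self-contained sketch of Method 1 follows. It assumes a small hypothetical `build_model()` helper and a temporary file standing in for `MODEL_PATH`. The point it illustrates is that `load_state_dict` needs an already-constructed model with the same architecture, because the saved file holds only the parameter and buffer tensors.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical architecture, used only for illustration.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
model = build_model()

# Stand-in for MODEL_PATH; any extension (.pt, .mdl, ...) works.
MODEL_PATH = os.path.join(tempfile.mkdtemp(), "model.pt")

# Save only the parameters and buffers (a dict of name -> tensor).
torch.save(model.state_dict(), MODEL_PATH)

# To reload: rebuild the same architecture first, then copy the weights in.
restored = build_model()
restored.load_state_dict(torch.load(MODEL_PATH))
restored.eval()  # inference mode; matters for dropout / batch norm layers

x = torch.randn(3, 4)
same_output = torch.equal(model(x), restored(x))
print(same_output)  # True
```

Because only tensors are stored, this file stays loadable even if the model class is later moved or renamed, as long as the layer names in the state dict still match.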
  • Method 2: save the entire model

Save the entire model object; the file extension can be .pth.tar:

torch.save(model, PATH)

Then load it:

model = torch.load(PATH)
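Below is a minimal sketch of Method 2, assuming a temporary file in place of `PATH`. This method pickles the whole module object, so the model's class must be importable when loading; the sketch uses `nn.Sequential` to sidestep that. As far as I know, recent PyTorch releases (2.6 and later) default `torch.load` to `weights_only=True`, which refuses to unpickle a full `nn.Module`, so the sketch passes `weights_only=False` explicitly (the argument exists from PyTorch 1.13 on; only use it for files you trust).

```python
import os
import tempfile

import torch
import torch.nn as nn

# nn.Sequential is importable from torch itself; a custom nn.Module subclass
# would also need to be importable from the same module path at load time.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Stand-in for PATH.
PATH = os.path.join(tempfile.mkdtemp(), "model.pth.tar")

# Pickle the whole module object: architecture and weights together.
torch.save(model, PATH)

# Full unpickling is required here; only do this for trusted files.
loaded = torch.load(PATH, weights_only=False)
loaded.eval()

print(type(loaded).__name__)  # Sequential
```

The convenience comes at the cost of coupling the file to the exact class definitions and module layout used when saving, which is why Method 1 is usually the more robust choice.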

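There are also some details to handle when moving models between the GPU and the CPU (for example, loading a GPU-trained model on a CPU-only machine); see the referenced blog or the official documentation for specifics.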
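As a rough sketch of the GPU/CPU case: `torch.load` takes a `map_location` argument that remaps where stored tensors are placed, and `model.to(device)` moves a module's parameters. The example assumes the state_dict approach from Method 1 and a temporary file path; it runs on CPU-only machines too, since `device` falls back to `"cpu"`.

```python
import os
import tempfile

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Save from a model that may live on the GPU.
model = nn.Linear(4, 2).to(device)
MODEL_PATH = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(model.state_dict(), MODEL_PATH)

# Load on a CPU-only machine: map_location remaps any CUDA tensors to CPU.
cpu_model = nn.Linear(4, 2)
cpu_model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))

# Load for GPU use (when available): load the weights, then move the module.
gpu_model = nn.Linear(4, 2)
gpu_model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
gpu_model.to(device)
```

Input tensors have to be moved to the same device as the model (e.g. `x.to(device)`) before calling it; mixing CPU and CUDA tensors raises a runtime error.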


Source: www.cnblogs.com/yqpy/p/11497259.html