There are two ways to save a model in PyTorch.
Method 1: save only the model's parameters (its state_dict):
torch.save(the_model.state_dict(), PATH)
When you need the parameters again, first create a new instance of the model class, then load the saved state into it:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
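As a rough, self-contained illustration of Method 1, the sketch below saves a state_dict, rebuilds the model and reloads the tensors. `TinyNet` is a hypothetical stand-in for `TheModelClass`. A temporary directory is used in place of a real `PATH`. The key point is that the architecture must be reconstructed in code with the same constructor arguments before `load_state_dict` can copy the tensors in.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass."""

    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


def save_and_restore_params(path):
    model = TinyNet()
    # Save only the learned tensors (an ordered mapping of name -> tensor).
    torch.save(model.state_dict(), path)

    # Rebuild the architecture with the same constructor arguments,
    # then copy the saved tensors into it.
    restored = TinyNet()
    restored.load_state_dict(torch.load(path))
    restored.eval()  # switch to inference mode if you are not resuming training
    return model, restored


with tempfile.TemporaryDirectory() as tmp:
    original, restored = save_and_restore_params(os.path.join(tmp, "params.pt"))
    x = torch.randn(3, 4)
    print(torch.equal(original(x), restored(x)))  # True
```

This is the approach the PyTorch documentation generally recommends, because the saved file contains only tensors and does not depend on pickling your class.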
Method 2: save the entire model object:
torch.save(the_model, PATH)
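To load it back (the model's class definition must still be importable, because the whole object is pickled):
the_model = torch.load(PATH, weights_only=False)  # weights_only=False is required on PyTorch 2.6+, where the default became True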
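A minimal sketch of Method 2, assuming PyTorch 1.13 or newer (the first release with the `weights_only` argument). `TinyNet` is again a hypothetical model. The whole object is pickled, so loading works only where `TinyNet` can be found under the same module path. Passing `weights_only=False` allows arbitrary classes to be unpickled, so it should only be used with files you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model; its class must be importable at load time."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    # Pickles the whole object: a reference to the class plus its parameters.
    torch.save(model, path)
    # weights_only=False is needed to unpickle arbitrary classes;
    # only do this with files you trust.
    loaded = torch.load(path, weights_only=False)

print(type(loaded).__name__)  # TinyNet
```

This form is convenient, but it ties the saved file to your code layout. Renaming or moving the class can break loading, which is one reason the state_dict method is usually preferred.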