Saving and loading a trained PyTorch model

There are two ways to save a trained model and load it again later.

Save only the parameters of the model:

torch.save(the_model.state_dict(), PATH)

To restore the parameters later, first create an instance of the model class, then load the saved parameters into it:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
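As a minimal sketch of the parameter-only round trip, the snippet below saves and restores a small network. TinyNet, its layer sizes, and the temporary file name are hypothetical stand-ins for TheModelClass and PATH; they are not from the original post.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass."""

    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


the_model = TinyNet()
the_model.eval()

# Hypothetical PATH: a file inside a fresh temporary directory.
path = os.path.join(tempfile.mkdtemp(), "tiny_net_params.pt")

# Save only the parameters (the state dict), not the model object.
torch.save(the_model.state_dict(), path)

# Restoring requires building the model with the same architecture first.
restored = TinyNet()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch to inference mode before evaluating

# Both models should now produce identical outputs for the same input.
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(the_model(x), restored(x))
```

Because only tensors are stored, the file does not depend on how the class is defined, but the code that loads it must construct the model with matching layer names and shapes.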

Save the entire model:

torch.save(the_model, PATH)

To load it again:

the_model = torch.load(PATH)
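The whole-model approach can be sketched as follows. This example assumes PyTorch 1.13 or later, where torch.load accepts a weights_only argument; recent releases default it to True, which rejects pickled model objects, so it is set to False here. The nn.Sequential model and the file name are illustrative choices, not from the original post.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Illustrative model built from standard layers only.
the_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
the_model.eval()

# Hypothetical PATH: a file inside a fresh temporary directory.
path = os.path.join(tempfile.mkdtemp(), "whole_model.pt")

# Pickles the entire module object, structure and parameters together.
torch.save(the_model, path)

# weights_only=False permits unpickling arbitrary objects,
# so use it only with files from a trusted source.
restored = torch.load(path, weights_only=False)
restored.eval()

# The restored model should reproduce the original outputs.
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(the_model(x), restored(x))
```

Saving the whole object relies on pickle, so for a custom model class the class definition must still be importable from the same module path when the file is loaded. Saving only the parameters avoids that coupling.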
Origin blog.csdn.net/qq_45171138/article/details/104545944