Tutorials > Saving and Loading Models

# -*- coding: utf-8 -*-
"""
Saving and Loading Models
=========================
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_

This document provides solutions to a variety of use cases regarding the
saving and loading of PyTorch models. Feel free to read the whole
document, or just skip to the code you need for a desired use case.

When it comes to saving and loading models, there are three core
functions to be familiar with:

1) `torch.save <https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save>`__:
   Saves a serialized object to disk. This function uses Python’s
   `pickle <https://docs.python.org/3/library/pickle.html>`__ utility
   for serialization. Models, tensors, and dictionaries of all kinds of
   objects can be saved using this function.

2) `torch.load <https://pytorch.org/docs/stable/torch.html?highlight=torch%20load#torch.load>`__:
   Uses `pickle <https://docs.python.org/3/library/pickle.html>`__\ ’s
   unpickling facilities to deserialize pickled object files to memory.
   This function also facilitates the device to load the data into (see
   `Saving & Loading Model Across
   Devices <#saving-loading-model-across-devices>`__).

3) `torch.nn.Module.load_state_dict <https://pytorch.org/docs/stable/nn.html?highlight=load_state_dict#torch.nn.Module.load_state_dict>`__:
   Loads a model’s parameter dictionary using a deserialized
   *state_dict*. For more information on *state_dict*, see `What is a
   state_dict? <#what-is-a-state-dict>`__.



**Contents:**

-  `What is a state_dict? <#what-is-a-state-dict>`__
-  `Saving & Loading Model for
   Inference <#saving-loading-model-for-inference>`__
-  `Saving & Loading a General
   Checkpoint <#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training>`__
-  `Saving Multiple Models in One
   File <#saving-multiple-models-in-one-file>`__
-  `Warmstarting Model Using Parameters from a Different
   Model <#warmstarting-model-using-parameters-from-a-different-model>`__
-  `Saving & Loading Model Across
   Devices <#saving-loading-model-across-devices>`__

"""


######################################################################
# What is a ``state_dict``?
# -------------------------
#
# In PyTorch, the learnable parameters (i.e. weights and biases) of an
# ``torch.nn.Module`` model is contained in the model’s *parameters*
# (accessed with ``model.parameters()``). A *state_dict* is simply a
# Python dictionary object that maps each layer to its parameter tensor.
# Note that only layers with learnable parameters (convolutional layers,
# linear layers, etc.) have entries in the model’s *state_dict*. Optimizer
# objects (``torch.optim``) also have a *state_dict*, which contains
# information about the optimizer’s state, as well as the hyperparameters
# used.
#
# Because *state_dict* objects are Python dictionaries, they can be easily
# saved, updated, altered, and restored, adding a great deal of modularity
# to PyTorch models and optimizers.
#
# Example:
# ^^^^^^^^
#
# Let’s take a look at the *state_dict* from the simple model used in the
# `Training a
# classifier <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py>`__
# tutorial.
#
# .. code:: python
#
#    # Define model
#    class TheModelClass(nn.Module):
#        def __init__(self):
#            super(TheModelClass, self).__init__()
#            self.conv1 = nn.Conv2d(3, 6, 5)
#            self.pool = nn.MaxPool2d(2, 2)
#            self.conv2 = nn.Conv2d(6, 16, 5)
#            self.fc1 = nn.Linear(16 * 5 * 5, 120)
#            self.fc2 = nn.Linear(120, 84)
#            self.fc3 = nn.Linear(84, 10)
#
#        def forward(self, x):
#            x = self.pool(F.relu(self.conv1(x)))
#            x = self.pool(F.relu(self.conv2(x)))
#            x = x.view(-1, 16 * 5 * 5)
#            x = F.relu(self.fc1(x))
#            x = F.relu(self.fc2(x))
#            x = self.fc3(x)
#            return x
#
#    # Initialize model
#    model = TheModelClass()
#
#    # Initialize optimizer
#    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
#
#    # Print model's state_dict
#    print("Model's state_dict:")
#    for param_tensor in model.state_dict():
#        print(param_tensor, "\t", model.state_dict()[param_tensor].size())
#
#    # Print optimizer's state_dict
#    print("Optimizer's state_dict:")
#    for var_name in optimizer.state_dict():
#        print(var_name, "\t", optimizer.state_dict()[var_name])
#
# **Output:**
#
# ::
#
#    Model's state_dict:
#    conv1.weight     torch.Size([6, 3, 5, 5])
#    conv1.bias   torch.Size([6])
#    conv2.weight     torch.Size([16, 6, 5, 5])
#    conv2.bias   torch.Size([16])
#    fc1.weight   torch.Size([120, 400])
#    fc1.bias     torch.Size([120])
#    fc2.weight   torch.Size([84, 120])
#    fc2.bias     torch.Size([84])
#    fc3.weight   torch.Size([10, 84])
#    fc3.bias     torch.Size([10])
#
#    Optimizer's state_dict:
#    state    {}
#    param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
#


######################################################################
# Saving & Loading Model for Inference
# ------------------------------------
#
# Save/Load ``state_dict`` (Recommended)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    model = TheModelClass(*args, **kwargs)
#    model.load_state_dict(torch.load(PATH))
#    model.eval()
#
# When saving a model for inference, it is only necessary to save the
# trained model’s learned parameters. Saving the model’s *state_dict* with
# the ``torch.save()`` function will give you the most flexibility for
# restoring the model later, which is why it is the recommended method for
# saving models.
#
# A common PyTorch convention is to save models using either a ``.pt`` or
# ``.pth`` file extension.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
#
# .. Note ::
#
#    Notice that the ``load_state_dict()`` function takes a dictionary
#    object, NOT a path to a saved object. This means that you must
#    deserialize the saved *state_dict* before you pass it to the
#    ``load_state_dict()`` function. For example, you CANNOT load using
#    ``model.load_state_dict(PATH)``.
#
#
# Save/Load Entire Model
# ^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model, PATH)
#
# **Load:**
#
# .. code:: python
#
#    # Model class must be defined somewhere
#    model = torch.load(PATH)
#    model.eval()
#
# This save/load process uses the most intuitive syntax and involves the
# least amount of code. Saving a model in this way will save the entire
# module using Python’s
# `pickle <https://docs.python.org/3/library/pickle.html>`__ module. The
# disadvantage of this approach is that the serialized data is bound to
# the specific classes and the exact directory structure used when the
# model is saved. The reason for this is because pickle does not save the
# model class itself. Rather, it saves a path to the file containing the
# class, which is used during load time. Because of this, your code can
# break in various ways when used in other projects or after refactors.
#
# A common PyTorch convention is to save models using either a ``.pt`` or
# ``.pth`` file extension.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
#


######################################################################
# Saving & Loading a General Checkpoint for Inference and/or Resuming Training
# ----------------------------------------------------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#    torch.save({
#                'epoch': epoch,
#                'model_state_dict': model.state_dict(),
#                'optimizer_state_dict': optimizer.state_dict(),
#                'loss': loss,
#                ...
#                }, PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#    model = TheModelClass(*args, **kwargs)
#    optimizer = TheOptimizerClass(*args, **kwargs)
#
#    checkpoint = torch.load(PATH)
#    model.load_state_dict(checkpoint['model_state_dict'])
#    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#    epoch = checkpoint['epoch']
#    loss = checkpoint['loss']
#
#    model.eval()
#    # - or -
#    model.train()
#
# When saving a general checkpoint, to be used for either inference or
# resuming training, you must save more than just the model’s
# *state_dict*. It is important to also save the optimizer’s *state_dict*,
# as this contains buffers and parameters that are updated as the model
# trains. Other items that you may want to save are the epoch you left off
# on, the latest recorded training loss, external ``torch.nn.Embedding``
# layers, etc.
#
# To save multiple components, organize them in a dictionary and use
# ``torch.save()`` to serialize the dictionary. A common PyTorch
# convention is to save these checkpoints using the ``.tar`` file
# extension.
#
# To load the items, first initialize the model and optimizer, then load
# the dictionary locally using ``torch.load()``. From here, you can easily
# access the saved items by simply querying the dictionary as you would
# expect.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results. If you
# wish to resuming training, call ``model.train()`` to ensure these layers
# are in training mode.
#


######################################################################
# Saving Multiple Models in One File
# ----------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#    torch.save({
#                'modelA_state_dict': modelA.state_dict(),
#                'modelB_state_dict': modelB.state_dict(),
#                'optimizerA_state_dict': optimizerA.state_dict(),
#                'optimizerB_state_dict': optimizerB.state_dict(),
#                ...
#                }, PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#    modelA = TheModelAClass(*args, **kwargs)
#    modelB = TheModelBClass(*args, **kwargs)
#    optimizerA = TheOptimizerAClass(*args, **kwargs)
#    optimizerB = TheOptimizerBClass(*args, **kwargs)
#
#    checkpoint = torch.load(PATH)
#    modelA.load_state_dict(checkpoint['modelA_state_dict'])
#    modelB.load_state_dict(checkpoint['modelB_state_dict'])
#    optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
#    optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
#
#    modelA.eval()
#    modelB.eval()
#    # - or -
#    modelA.train()
#    modelB.train()
#
# When saving a model comprised of multiple ``torch.nn.Modules``, such as
# a GAN, a sequence-to-sequence model, or an ensemble of models, you
# follow the same approach as when you are saving a general checkpoint. In
# other words, save a dictionary of each model’s *state_dict* and
# corresponding optimizer. As mentioned before, you can save any other
# items that may aid you in resuming training by simply appending them to
# the dictionary.
#
# A common PyTorch convention is to save these checkpoints using the
# ``.tar`` file extension.
#
# To load the models, first initialize the models and optimizers, then
# load the dictionary locally using ``torch.load()``. From here, you can
# easily access the saved items by simply querying the dictionary as you
# would expect.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results. If you
# wish to resuming training, call ``model.train()`` to set these layers to
# training mode.
#


######################################################################
# Warmstarting Model Using Parameters from a Different Model
# ----------------------------------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#    torch.save(modelA.state_dict(), PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#    modelB = TheModelBClass(*args, **kwargs)
#    modelB.load_state_dict(torch.load(PATH), strict=False)
#
# Partially loading a model or loading a partial model are common
# scenarios when transfer learning or training a new complex model.
# Leveraging trained parameters, even if only a few are usable, will help
# to warmstart the training process and hopefully help your model converge
# much faster than training from scratch.
#
# Whether you are loading from a partial *state_dict*, which is missing
# some keys, or loading a *state_dict* with more keys than the model that
# you are loading into, you can set the ``strict`` argument to **False**
# in the ``load_state_dict()`` function to ignore non-matching keys.
#
# If you want to load parameters from one layer to another, but some keys
# do not match, simply change the name of the parameter keys in the
# *state_dict* that you are loading to match the keys in the model that
# you are loading into.
#


######################################################################
# Saving & Loading Model Across Devices
# -------------------------------------
#
# Save on GPU, Load on CPU
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    device = torch.device('cpu')
#    model = TheModelClass(*args, **kwargs)
#    model.load_state_dict(torch.load(PATH, map_location=device))
#
# When loading a model on a CPU that was trained with a GPU, pass
# ``torch.device('cpu')`` to the ``map_location`` argument in the
# ``torch.load()`` function. In this case, the storages underlying the
# tensors are dynamically remapped to the CPU device using the
# ``map_location`` argument.
#
# Save on GPU, Load on GPU
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    device = torch.device("cuda")
#    model = TheModelClass(*args, **kwargs)
#    model.load_state_dict(torch.load(PATH))
#    model.to(device)
#    # Make sure to call input = input.to(device) on any input tensors that you feed to the model
#
# When loading a model on a GPU that was trained and saved on GPU, simply
# convert the initialized ``model`` to a CUDA optimized model using
# ``model.to(torch.device('cuda'))``. Also, be sure to use the
# ``.to(torch.device('cuda'))`` function on all model inputs to prepare
# the data for the model. Note that calling ``my_tensor.to(device)``
# returns a new copy of ``my_tensor`` on GPU. It does NOT overwrite
# ``my_tensor``. Therefore, remember to manually overwrite tensors:
# ``my_tensor = my_tensor.to(torch.device('cuda'))``.
#
# Save on CPU, Load on GPU
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    device = torch.device("cuda")
#    model = TheModelClass(*args, **kwargs)
#    model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
#    model.to(device)
#    # Make sure to call input = input.to(device) on any input tensors that you feed to the model
#
# When loading a model on a GPU that was trained and saved on CPU, set the
# ``map_location`` argument in the ``torch.load()`` function to
# *cuda:device_id*. This loads the model to a given GPU device. Next, be
# sure to call ``model.to(torch.device('cuda'))`` to convert the model’s
# parameter tensors to CUDA tensors. Finally, be sure to use the
# ``.to(torch.device('cuda'))`` function on all model inputs to prepare
# the data for the CUDA optimized model. Note that calling
# ``my_tensor.to(device)`` returns a new copy of ``my_tensor`` on GPU. It
# does NOT overwrite ``my_tensor``. Therefore, remember to manually
# overwrite tensors: ``my_tensor = my_tensor.to(torch.device('cuda'))``.
#
# Saving ``torch.nn.DataParallel`` Models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#    torch.save(model.module.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#    # Load to whatever device you want
#
# ``torch.nn.DataParallel`` is a model wrapper that enables parallel GPU
# utilization. To save a ``DataParallel`` model generically, save the
# ``model.module.state_dict()``. This way, you have the flexibility to
# load the model any way you want to any device you want.
#

猜你喜欢

转载自www.cnblogs.com/houjun/p/10214452.html
今日推荐