state_dictとload_state_dictの詳細Pytorchソース

Pytorchが保存され、ロードされた方法は以下のモデルでは:

# save
torch.save(model.state_dict(), PATH)

# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

model.state_dict()実際には、のリターンOrderDict保存された名前とネットワーク構造の対応するパラメータは、どのようにソースコードの実装を見てみましょう。

state_dict

# torch.nn.modules.module.py
class Module(object):
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param if keep_vars else param.data
        for name, buf in self._buffers.items():
            if buf is not None:
                destination[prefix + name] = buf if keep_vars else buf.data
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, prefix, local_metadata)
            if hook_result is not None:
                destination = hook_result
        return destination

State_dict関数は4、それぞれの要素を横断見ることができる_paramters_buffers_modules及び_state_dict_hooks前三の前に、記事の最後の読み取り、ある差を導入して、state_dictそれほど考慮されていない、通常は空場合願いが行わ。ノートへのもう一つのポイントは、それは、読んでいるModule読み取りモード再帰を使用し、相互の名前を使用すると.、後で容易にするため、セグメンテーションを行うload_state_dict読んパラメータを。

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_tensor = torch.randn(1) # 参数直接作为模型类成员变量
        self.register_buffer('my_buffer', torch.randn(1)) # 参数注册为 buffer
        self.my_param = nn.Parameter(torch.randn(1))
        self.fc = nn.Linear(2,2,bias=False)
        self.conv = nn.Conv2d(2,1,1)
        self.fc2 = nn.Linear(2,2,bias=False)
        self.f3 = self.fc
    def forward(self, x):
        return x

model = MyModel()
print(model.state_dict())
>>>OrderedDict([('my_param', tensor([-0.3052])), ('my_buffer', tensor([0.5583])), ('fc.weight', tensor([[ 0.6322, -0.0255],
        [-0.4747, -0.0530]])), ('conv.weight', tensor([[[[ 0.3346]],

         [[-0.2962]]]])), ('conv.bias', tensor([0.5205])), ('fc2.weight', tensor([[-0.4949,  0.2815],
        [ 0.3006,  0.0768]])), ('f3.weight', tensor([[ 0.6322, -0.0255],
        [-0.4747, -0.0530]]))])

我々は確かに最後の3つの出力パラメータを見ることができます。

load_state_dict

次のコードは、私たちが見る、二つの部分に分けることができ、

  1. load(self)

請求この関数は再帰的に、モデルパラメータを復元_load_from_state_dictソースは、テキストの端部に取り付けられています。

まず、クリアする必要があるstate_dict一方で、以前に保存されたシーケンス・パラメータ・モデルを示し、この変数を_load_from_state_dict関数local_stateモデル構造の表現、あなたのコード定義。

だから_load_from_state_dict、単純な理解の役割は、私たちが今の名前が必要な場合ということであるconv.weightサブモジュールを復元するには、引数として、再帰的に決定するためにconvかどうかstaet__dictlocal_state、入れない場合は、convに追加unexpected_keysそれ以外の場合は、再帰的に決定し、行くconv.weightがあるかどうか、が実行された場合はparam.copy_(input_param)、これが完了conv.weightパラメータコピーを。

  1. if strict:

この部分の役割は、上記のように、コピー・プロセス・パラメータがいるかどうかを決定することであるunexpected_keys、またはmissing_keys、エラーの場合は、コードを継続することができません。もちろん、もしstrict=Falseこれらの詳細は無視されます。

def load_state_dict(self, state_dict, strict=True):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(self)

    if strict:
        error_msg = ''
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                           self.__class__.__name__, "\n\t".join(error_msgs)))
  • _load_from_state_dict
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
    local_state = {k: v.data for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]

            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]

            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                  'the shape in current model is {}.'
                                  .format(key, input_param.shape, param.shape))
                continue

            if isinstance(input_param, Parameter):
                # backwards compatibility for serialized parameters
                input_param = input_param.data
            try:
                param.copy_(input_param)
            except Exception:
                error_msgs.append('While copying the parameter named "{}", '
                                  'whose dimensions in the model are {} and '
                                  'whose dimensions in the checkpoint are {}.'
                                  .format(key, param.size(), input_param.size()))
        elif strict:
            missing_keys.append(key)

    if strict:
        for key, input_param in state_dict.items():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)


MARSGGBO オリジナル


興味を持っている場合、プライベートスタンプを歓迎

Eメール:[email protected]


2019年12月20日夜9時55分21秒



おすすめ

転載: www.cnblogs.com/marsggbo/p/12075356.html