PyTorch model.load_state_dict() error when loading a model: Unexpected key(s) in state_dict: "module.features...", Expected "features..."

I wanted to load a trained model into a new network. As the title describes, PyTorch raised an error when loading the previously saved model parameters:

    Unexpected key(s) in state_dict: "module.features. ...", Expected "features. ...". The direct cause is that the key names do not match.

    This means that during loading the model expects keys named features..., not module.features.... The mismatch comes from how the model was saved: the model must have been in DataParallel mode, that is, it was trained on multiple GPUs and its state_dict was then saved directly.

    You probably saved the model using nn.DataParallel, which stores the model in its module attribute, and now you are trying to load it without DataParallel. You can either temporarily wrap your network in nn.DataParallel for loading purposes, or load the weights file, create a new ordered dict without the module prefix, and load that instead.
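A quick way to see the prefix appear is to wrap a small network in nn.DataParallel and compare the keys. This is a minimal sketch: the small nn.Sequential below is a hypothetical stand-in for your own model (e.g. VGG), and nn.DataParallel can be constructed on a machine without GPUs, in which case it simply calls the wrapped module.

```python
import torch.nn as nn

# Hypothetical stand-in for your own model; any nn.Module behaves the same way.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
wrapped = nn.DataParallel(net)  # the original network is stored as wrapped.module

plain_keys = list(net.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())

print(plain_keys)    # ['0.weight', '0.bias', '2.weight', '2.bias']
print(wrapped_keys)  # ['module.0.weight', 'module.0.bias', 'module.2.weight', 'module.2.bias']

# The wrapped network itself still produces the unprefixed keys.
print(list(wrapped.module.state_dict().keys()) == plain_keys)  # True
```

The last line is also why saving wrapped.module.state_dict() instead of wrapped.state_dict() gives a checkpoint without the "module." prefix.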

There are three ways to solve the above problem:

1. Create a new dictionary from the loaded state_dict, removing the unnecessary "module." prefix from every key.

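import torch
from collections import OrderedDict

# original checkpoint saved from a DataParallel-wrapped model
state_dict = torch.load('checkpoint.pt', map_location='cpu')  # the file can be saved as .pth or .pt
# create a new OrderedDict whose keys do not contain `module.`
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # k[7:] keeps the characters from index 7 to the end, which removes exactly `module.`;
    # only strip it when the key actually starts with that prefix
    name = k[7:] if k.startswith('module.') else k
    new_state_dict[name] = v  # each new key maps to the same value as the original key
# load params into your (unwrapped) model instance
model.load_state_dict(new_state_dict)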
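Method 1 can be checked end to end with a small round trip: save a state_dict from a DataParallel-wrapped model, strip the prefix, and load it into an unwrapped copy with the default strict=True. This is a minimal sketch under assumptions: the tiny network, the in-memory io.BytesIO buffer (standing in for checkpoint.pt) and the helper name strip_module_prefix are all hypothetical.

```python
import io
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Hypothetical helper: return a copy of state_dict with `prefix` removed
    from every key that starts with it (other keys are kept unchanged)."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict


def make_net():
    # Hypothetical stand-in for your own model.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


# Save from a DataParallel model; the saved keys start with "module.".
trained = nn.DataParallel(make_net())
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Load into a plain (unwrapped) model.
state_dict = torch.load(buffer, map_location="cpu")
model = make_net()
model.load_state_dict(strip_module_prefix(state_dict))  # strict=True by default

# The loaded weights match the saved ones exactly.
same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                            trained.module.state_dict().values()))
print(same)  # True
```

Checking startswith before slicing means a checkpoint that was saved without DataParallel passes through unchanged instead of losing its first seven characters.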

2. Directly replace 'module.' with an empty string ''

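model.load_state_dict({k.replace('module.', '', 1): v for k, v in torch.load('checkpoint.pt').items()})

# Equivalent to replacing 'module.' with ''.
# This directly makes the loaded key names equal to the expected key names.
# The count argument 1 replaces only the first (leading) 'module.' in each key.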
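Note that str.replace without a count argument removes every occurrence of 'module.', not only the leading one, so a submodule whose name ends in "module" (for example a hypothetical attribute called submodule) would have its key corrupted. A small sketch of the difference, using plain strings in place of real state_dict keys; str.removeprefix requires Python 3.9 or later.

```python
# Hypothetical keys as they might appear in a checkpoint saved with DataParallel.
keys = ["module.features.0.weight", "module.submodule.conv.weight"]

# replace() without a count removes every "module." substring.
replaced = [k.replace("module.", "") for k in keys]
print(replaced)  # ['features.0.weight', 'subconv.weight']  (second key is corrupted)

# removeprefix() only strips the leading prefix (Python 3.9+).
stripped = [k.removeprefix("module.") for k in keys]
print(stripped)  # ['features.0.weight', 'submodule.conv.weight']
```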

3. The easiest way: instantiate the model, wrap it in DataParallel, and then call load_state_dict as usual.

Wrapping the model in DataParallel adds "module." to the front of every key, so the model's keys then match those in the checkpoint saved from the DataParallel model.

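import torch
import torch.nn as nn

model = VGG()  # instantiate your own model
checkpoint = torch.load('checkpoint.pt', map_location='cpu')  # .pt and .pth files both work
# Wrap the model in DataParallel; this adds the "module." prefix to every key.
# Wrap unconditionally: if the wrap only happened when several GPUs are present, the keys
# would not match on a single-GPU or CPU-only machine (DataParallel can be built without GPUs).
model = nn.DataParallel(model)
model.load_state_dict(checkpoint)  # the keys now match, so the parameters load into the model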
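A minimal sketch of method 3 end to end, with a hypothetical tiny network in place of VGG and an in-memory buffer in place of checkpoint.pt. The wrapped model accepts the prefixed keys directly, and afterwards model.module gives back the plain network, whose state_dict has no prefix.

```python
import io

import torch
import torch.nn as nn


def make_net():
    # Hypothetical stand-in for VGG().
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


# A checkpoint saved from a DataParallel model: keys start with "module.".
buffer = io.BytesIO()
torch.save(nn.DataParallel(make_net()).state_dict(), buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")

# Method 3: wrap first, then load; the keys match without any renaming.
model = nn.DataParallel(make_net())
model.load_state_dict(checkpoint)

# model.module is the original network; its keys carry no prefix.
plain_state_dict = model.module.state_dict()
print(list(plain_state_dict.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']
```

Saving plain_state_dict at this point would give a checkpoint that loads directly into an unwrapped model.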

4. Summary

    As the error message shows, the keys do not match, so any of the methods above can be used to load the model parameters. This error usually comes up in load_state_dict, either when transplanting a trained network's parameters into another network to continue training, or when loading a trained checkpoint into the model to train again. Printing the keys of the model's state_dict and of the checkpoint shows the difference between the two.

model = VGGNet()  # your own model class
params = model.state_dict()  # get the model's initial state, including its parameters
for k, v in params.items():
    print(k)  # print only the keys, not the parameter values

The output is similar to:

features.0.0.weight
features.0.1.weight
features.1.conv.3.weight
features.1.conv.4.num_batches_tracked


import torch

model = VGGNet()  # your own model class
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
# Load the weights to resume from the checkpoint.
# This loop prints the keys of the checkpoint you saved (keys only, not the tensor values).
for k, v in checkpoint.items():
    print(k)
print("*****************************************")

The output is:

module.features.0.0.weight
module.features.0.1.weight
module.features.0.1.bias

It can be seen that the keys do not match: every key in the checkpoint has an extra "module." prefix compared with the keys of the model's own parameters.
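Rather than comparing the two printouts by eye, the key sets can be compared directly. The helper below is a hypothetical sketch that works on any two iterables of key names (for example model.state_dict().keys() and checkpoint.keys()): it reports keys missing from the checkpoint, unexpected extra keys, and whether every checkpoint key is just a model key plus a common prefix such as "module.". The sample keys are hypothetical, modeled on the printouts above.

```python
def compare_keys(model_keys, ckpt_keys, prefix="module."):
    """Hypothetical helper: summarize how checkpoint keys differ from model keys."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = sorted(model_keys - ckpt_keys)     # expected by the model, absent in the checkpoint
    unexpected = sorted(ckpt_keys - model_keys)  # present in the checkpoint, unknown to the model
    # True when stripping `prefix` from every checkpoint key reproduces the model keys.
    prefix_only = all(k.startswith(prefix) for k in ckpt_keys) and {
        k[len(prefix):] for k in ckpt_keys
    } == model_keys
    return missing, unexpected, prefix_only


model_keys = ["features.0.0.weight", "features.0.1.weight", "features.0.1.bias"]
ckpt_keys = ["module.features.0.0.weight", "module.features.0.1.weight",
             "module.features.0.1.bias"]

missing, unexpected, prefix_only = compare_keys(model_keys, ckpt_keys)
print(missing)      # ['features.0.0.weight', 'features.0.1.bias', 'features.0.1.weight']
print(unexpected)   # ['module.features.0.0.weight', 'module.features.0.1.bias', 'module.features.0.1.weight']
print(prefix_only)  # True
```

PyTorch's own model.load_state_dict(..., strict=False) also returns a named tuple with missing_keys and unexpected_keys fields, which gives the same two lists without raising an error.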


PS: 2020-12-25

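When porting parameters, if the checkpoint also contains entries ending in .total_ops and .total_params (typically buffers registered by FLOPs/parameter-counting tools such as thop), which the target model does not have, you can skip them as in the following code, which also strips the module. prefix: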

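import torch
from collections import OrderedDict

# pretrained_model_file_path, use_cuda, remap_to_cpu and model come from your own script
checkpoint = torch.load(
    pretrained_model_file_path,
    map_location=(None if use_cuda and not remap_to_cpu else "cpu"))
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    # skip the .total_ops / .total_params entries, which the model does not have
    if not k.endswith('total_ops') and not k.endswith('total_params'):
        name = k[7:] if k.startswith('module.') else k  # strip the `module.` prefix
        new_state_dict[name] = v
model.load_state_dict(new_state_dict)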
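The same filtering can be folded into one reusable function. A minimal sketch on plain dicts: the function name clean_state_dict and its parameters are hypothetical, and the sample entries (including the total_ops/total_params keys) are made up for illustration. With a real checkpoint you would pass the loaded dict and then call model.load_state_dict on the result.

```python
from collections import OrderedDict


def clean_state_dict(state_dict, prefix="module.", drop_suffixes=("total_ops", "total_params")):
    """Hypothetical helper: drop entries whose key ends with any of `drop_suffixes`
    and strip `prefix` from the keys that start with it."""
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        if k.endswith(tuple(drop_suffixes)):
            continue  # e.g. counters left behind by a profiling tool
        cleaned[k[len(prefix):] if k.startswith(prefix) else k] = v
    return cleaned


# Made-up entries; in a real checkpoint the values would be tensors.
checkpoint = OrderedDict([
    ("module.features.0.0.weight", 1),
    ("module.features.0.0.total_ops", 2),
    ("module.features.0.0.total_params", 3),
    ("module.total_ops", 4),
    ("module.classifier.bias", 5),
])
print(list(clean_state_dict(checkpoint)))  # ['features.0.0.weight', 'classifier.bias']
```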



Origin blog.csdn.net/qq_32998593/article/details/89343507