Error reporting when loading a model in distributed training

1. The problem: when loading a model in distributed training, the checkpoint's parameter names do not match the network's parameter names:

Problem 1: The current model is not wrapped for distributed training, but the checkpoint being loaded was saved from a wrapped model, so its keys carry the "module." prefix.

Solution:

new_state_dict = {}
state_dict = checkpoint['state_dict']
for k, v in state_dict.items():
    name = k[7:]  # strip the leading "module." prefix (7 characters)
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)

The above code first creates an empty dictionary new_state_dict, then iterates over each key-value pair in the original state_dict, strips the "module." prefix (7 characters) from each key name, and finally loads the processed state_dict into the model.

The purpose of this is to ensure that when a checkpoint saved during distributed training is loaded into a non-distributed model, the "module." prefix is removed from every key name so that the keys match the model's structure.
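
Putting this into a complete loading flow, here is a minimal sketch (the file name checkpoint.pth, the 'state_dict' key, and the nn.Linear stand-in are assumptions; adjust them to how your checkpoint was saved). It also guards against keys that never had the prefix:

import torch
import torch.nn as nn

# Hypothetical stand-in for the real network; replace with your model class.
model = nn.Linear(10, 2)

checkpoint = torch.load('checkpoint.pth', map_location='cpu')
state_dict = checkpoint['state_dict']

new_state_dict = {}
for k, v in state_dict.items():
    # Strip "module." only when it is actually present, so the same loop
    # also handles checkpoints saved without the prefix.
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)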

Problem 2: The current model is wrapped for distributed training, but the checkpoint being loaded was saved from an unwrapped model, so its keys lack the "module." prefix.

Solution:

new_state_dict = {}
state_dict = checkpoint['state_dict']
for k, v in state_dict.items():
    name = 'module.' + k  # prepend "module." to the key name
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)

The above code first creates an empty dictionary new_state_dict, then iterates over each key-value pair in the original state_dict, prepends the "module." prefix to each key name, and finally loads the processed state_dict into the model.

The purpose of this is to ensure that when such a checkpoint is loaded into a distributed (wrapped) model, the "module." prefix is added back to every key name so that the keys match the model's structure.
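
An alternative that avoids rebuilding the dictionary: the wrapper keeps the original network in its .module attribute, so the un-prefixed state_dict can be loaded straight into it. A minimal sketch, assuming nn.DataParallel as the wrapper and an nn.Linear stand-in for the real network:

import torch
import torch.nn as nn

net = nn.Linear(10, 2)        # stand-in for the real network
model = nn.DataParallel(net)  # keys of model.state_dict() now start with "module."

# Un-prefixed checkpoint, e.g. saved from the unwrapped network.
state_dict = {'weight': torch.randn(2, 10), 'bias': torch.randn(2)}

# Loading into the inner network bypasses the "module." prefix entirely.
model.module.load_state_dict(state_dict)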

Note: the root cause is how the model was saved with torch.save. Saving model.state_dict() from a wrapped model keeps the "module." prefix in the keys, while saving model.module.state_dict() produces keys without it.
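
For reference, the two saving styles that produce the mismatch look like this (a sketch; the file names and the nn.Linear stand-in are illustrative):

import torch
import torch.nn as nn

model = nn.DataParallel(nn.Linear(10, 2))

# Keys are saved as "module.weight", "module.bias" -> triggers Problem 1
# when loaded into an unwrapped model.
torch.save({'state_dict': model.state_dict()}, 'ckpt_with_prefix.pth')

# Keys are saved as "weight", "bias" -> triggers Problem 2 when loaded
# into a wrapped model.
torch.save({'state_dict': model.module.state_dict()}, 'ckpt_without_prefix.pth')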


Origin blog.csdn.net/m0_62278731/article/details/134749627