PyTorch raises an error when loading a model: RuntimeError: Error(s) in loading state_dict for *****: Missing key(s) in state_dict:

Problem Description:

    The original author's code had no support for resuming training from a checkpoint. I added it, which meant saving more than the model weights: the checkpoint now stores the epoch, net.state_dict(), optimizer.state_dict(), and scheduler.state_dict().

The original code to save the model is as follows:

torch.save(net.state_dict(), model_dir)

After adding the information, the code to save the model is as follows:

torch.save({'epoch': i,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()},
           model_dir)

The original code to load the model is as follows:

net.load_state_dict(torch.load(model_dir))

After adding the information, the code to load the model is as follows:

ckpt = torch.load(model_dir, map_location='cpu')
net.load_state_dict(ckpt['model_state_dict'])
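
To actually resume training, the optimizer and scheduler states and the epoch counter have to be restored as well. The following is a minimal sketch under stated assumptions: the nn.Linear network, the SGD optimizer, the StepLR scheduler, the temporary path, and the name start_epoch are all hypothetical stand-ins, not the original author's code.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-ins for the real network, optimizer and scheduler.
net = nn.Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

model_dir = os.path.join(tempfile.mkdtemp(), "ckpt.pth")  # illustrative path

# Save at the end of epoch i, using the same dictionary keys as the post.
i = 3
torch.save({'epoch': i,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()},
           model_dir)

# Resume: rebuild the objects first, then restore each saved state.
ckpt = torch.load(model_dir, map_location='cpu')
net.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler'])
start_epoch = ckpt['epoch'] + 1  # continue with the epoch after the saved one
```

The training loop would then start from start_epoch instead of 0.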

When testing inference, an error is reported when loading the model:

net.load_state_dict(ckpt['model_state_dict'])
  File "/root/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for *****:
	Missing key(s) in state_dict: 
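
A quick way to see why the load fails is to compare the key names in the checkpoint with the key names the model expects. The sketch below is illustrative only: it uses a small nn.Linear as a hypothetical model and produces the mismatch by wrapping a second copy in nn.DataParallel, which is one common way such a mismatch arises.

```python
import torch.nn as nn

# Hypothetical model used for inference (unwrapped).
net = nn.Linear(4, 2)

# Hypothetical saved weights, produced here by a wrapped copy of the model.
saved_state = nn.DataParallel(nn.Linear(4, 2)).state_dict()

model_keys = set(net.state_dict().keys())
ckpt_keys = set(saved_state.keys())

# Keys the model needs but the checkpoint lacks, and the reverse.
missing = sorted(model_keys - ckpt_keys)
unexpected = sorted(ckpt_keys - model_keys)

print("missing:", missing)
print("unexpected:", unexpected)
```

If the two lists differ only by a common prefix, the weights themselves are fine and only the key names need fixing.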

Solution:

Method 1: Pass strict=False. The load succeeds, but every parameter whose key does not match is silently skipped, so the model can be left with its initial weights and the inference results may be wrong.

ckpt = torch.load(model_dir)
net.load_state_dict(ckpt['model_state_dict'], strict=False)
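
The risk with strict=False can be checked directly, because load_state_dict returns an object listing the keys it could not match. The sketch below assumes a hypothetical nn.Linear model and a checkpoint whose keys all carry a module. prefix; in that situation nothing is loaded at all.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical checkpoint with prefixed keys: module.weight, module.bias.
trained = nn.DataParallel(nn.Linear(4, 2))

# Hypothetical inference model with plain keys: weight, bias.
net = nn.Linear(4, 2)
before = net.weight.detach().clone()

result = net.load_state_dict(trained.state_dict(), strict=False)
print(result.missing_keys)     # keys the model expected but did not receive
print(result.unexpected_keys)  # keys in the checkpoint that were ignored

# No key matched, so the model still holds its random initial weights.
weights_unchanged = torch.equal(net.weight, before)
```

Inspecting missing_keys and unexpected_keys after a non-strict load is a cheap safeguard against running inference on an untrained model.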

Method 2: Strip the module. prefix from the keys of the saved state dict. Alternatively, print the keys in the saved .pth file, compare them with the keys of the current model, and load the parameters manually.

ckpt = torch.load(args.weights, map_location='cpu')
net.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['model_state_dict'].items()})
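
One caveat: str.replace removes every occurrence of module. in a key, not only the leading one, so a layer attribute whose name ends in "module" would be renamed too. The helper below is a hypothetical sketch (strip_prefix is not a PyTorch function) that removes the prefix only at the start of each key; plain integers stand in for tensors.

```python
def strip_prefix(state_dict, prefix='module.'):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


# Placeholder values stand in for tensors; the key names are hypothetical.
saved = {'module.submodule.weight': 1, 'module.bias': 2}

cleaned = strip_prefix(saved)
replaced = {k.replace('module.', ''): v for k, v in saved.items()}

print(cleaned)   # leading prefix removed, inner names untouched
print(replaced)  # the inner "submodule." is damaged as well
```

With a real checkpoint the call would look like net.load_state_dict(strip_prefix(ckpt['model_state_dict'])).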

Root cause of the problem:

The training code wraps the model in nn.DataParallel:

net = nn.DataParallel(net)

To fix this at the source, find net = nn.DataParallel(net) in the training code, comment it out, and retrain.

Alternatively, keep the trained checkpoint and load it with Method 2 above.

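nn.DataParallel is PyTorch's wrapper for training one model on multiple GPUs. It holds the original network as its .module attribute, so every key in its state_dict() gains a module. prefix that a plain, unwrapped model does not expect.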
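
If multi-GPU training is still wanted, a third option is to keep the wrapper and save the state dict of the inner network, so the checkpoint loads into an unwrapped model without any key rewriting. This is a sketch under assumptions, not the original training script: the nn.Linear model is hypothetical and an in-memory buffer stands in for the checkpoint file.

```python
import io

import torch
import torch.nn as nn

net = nn.DataParallel(nn.Linear(4, 2))  # hypothetical wrapped training model

# Unwrap before saving so the keys carry no "module." prefix.
to_save = net.module if isinstance(net, nn.DataParallel) else net
buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save({'model_state_dict': to_save.state_dict()}, buffer)

# Inference side: a plain model loads the checkpoint with the default strict check.
buffer.seek(0)
ckpt = torch.load(buffer, map_location='cpu')
plain = nn.Linear(4, 2)
plain.load_state_dict(ckpt['model_state_dict'])

saved_keys = sorted(ckpt['model_state_dict'].keys())
print(saved_keys)
```

The isinstance check lets the same saving code run whether or not the model was wrapped.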

Origin blog.csdn.net/mj412828668/article/details/130014232