Cuando PyTorch carga el modelo, se produce un error: RuntimeError: Error(es) al cargar state_dict para *****: Faltan claves en state_dict:

Descripción del problema:

    No hay continuación del punto de interrupción del entrenamiento en el código del autor original. Agregué esta función e introduje más parámetros. Al guardar el modelo, agregué epoch, net.state_dict(), optimizador.state_dict() y Scheduler.state_dict(). y otra información.

El código original para guardar el modelo es el siguiente:

torch.save(net.state_dict(), model_dir)

Luego de agregar la información, el código para guardar el modelo es el siguiente:

 torch.save({'epoch': i,
             'model_state_dict': net.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),
             'scheduler': scheduler.state_dict(),},
              model_dir)

El código original para cargar el modelo es el siguiente:

net.load_state_dict(torch.load(model_dir))

Luego de agregar la información, el código para cargar el modelo es el siguiente:

ckpt = torch.load(model_dir, map_location='cpu')
net.load_state_dict(ckpt['model_state_dict'])

Al probar la inferencia, se informa un error al cargar el modelo:

net.load_state_dict(ckpt['model_state_dict'])
  File "/root/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for *****:
	Missing key(s) in state_dict: 

Solución:

El método 1 se puede cargar correctamente, pero algunos parámetros no se cargarán y, en algunos casos, los resultados de la inferencia serán incorrectos.

ckpt = torch.load(model_dir)
model.load_state_dict(ckpt['model_state_dict'],strict=False)

Método 2: Reemplace módulo en el valor de clave del diccionario, o compare la impresión de clave del archivo pth del modelo original con la clave del modelo actual y cargue manualmente los parámetros para el modelo.

ckpt = torch.load(args.weights, map_location='cpu')
net.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['model_state_dict'].items()})

Fuente del problema:

Se agregó el siguiente código al código de capacitación:

net = nn.DataParallel(net)

Busque net = nn.DataParallel(net) en el código de entrenamiento, coméntelo y vuelva a entrenar.

O utilice el método 2 anterior para cargar el modelo.

Para la computación paralela con múltiples GPU, se puede utilizar nn.DataParallel de Pytorch para entrenar el mismo modelo.

Supongo que te gusta

Origin blog.csdn.net/mj412828668/article/details/130014232
Recomendado
Clasificación