Descripción del problema:
No hay continuación del punto de interrupción del entrenamiento en el código del autor original. Agregué esta función e introduje más parámetros. Al guardar el modelo, agregué epoch, net.state_dict(), optimizador.state_dict() y Scheduler.state_dict(). y otra información.
El código original para guardar el modelo es el siguiente:
torch.save(net.state_dict(), model_dir)
Luego de agregar la información, el código para guardar el modelo es el siguiente:
torch.save({'epoch': i,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),},
model_dir)
El código original para cargar el modelo es el siguiente:
net.load_state_dict(torch.load(model_dir))
Luego de agregar la información, el código para cargar el modelo es el siguiente:
ckpt = torch.load(model_dir, map_location='cpu')
net.load_state_dict(ckpt['model_state_dict'])
Al probar la inferencia, se informa un error al cargar el modelo:
net.load_state_dict(ckpt['model_state_dict'])
File "/root/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for *****:
Missing key(s) in state_dict:
Solución:
El método 1 se puede cargar correctamente, pero algunos parámetros no se cargarán y, en algunos casos, los resultados de la inferencia serán incorrectos.
ckpt = torch.load(model_dir)
model.load_state_dict(ckpt['model_state_dict'],strict=False)
Método 2: Reemplace módulo en el valor de clave del diccionario, o compare la impresión de clave del archivo pth del modelo original con la clave del modelo actual y cargue manualmente los parámetros para el modelo.
ckpt = torch.load(args.weights, map_location='cpu')
net.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['model_state_dict'].items()})
Fuente del problema:
Se agregó el siguiente código al código de capacitación:
net = nn.DataParallel(net)
Busque net = nn.DataParallel(net) en el código de entrenamiento, coméntelo y vuelva a entrenar.
O utilice el método 2 anterior para cargar el modelo.
Para la computación paralela con múltiples GPU, se puede utilizar nn.DataParallel de Pytorch para entrenar el mismo modelo.