Descripción del problema
El modelo de red neuronal profunda es muy difícil de entrenar, por lo que el modelo de red de pytorch entrenado debe guardarse y se puede llamar directamente la próxima vez que se utilice. ¿Cómo se debe guardar este modelo?
Solución
Pytorch proporciona principalmente dos métodos, a saber: el método para guardar los parámetros del modelo y el método para guardar el modelo completo
Método 1: guardar solo los parámetros del modelo
#保存
torch.save(the_model.state_dict(), PATH)
#读取
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
Método 2: guarde el modelo completo
#保存
torch.save(the_model, PATH)
#读取
the_model = torch.load(PATH)
Debe tenerse en cuenta que el nombre de archivo recomendado para el archivo guardado es el formato .tar × ×
También puede consultar este artículo:
https://blog.csdn.net/weixin_41680653/article/details/93768559