Cómo guardar y llamar al modelo de red entrenado por pytorch

Descripción del problema

El modelo de red neuronal profunda es muy difícil de entrenar, por lo que el modelo de red de pytorch entrenado debe guardarse y se puede llamar directamente la próxima vez que se utilice. ¿Cómo se debe guardar este modelo?

Solución

Pytorch proporciona principalmente dos métodos, a saber: el método para guardar los parámetros del modelo y el método para guardar el modelo completo

Método 1: guardar solo los parámetros del modelo

#保存
torch.save(the_model.state_dict(), PATH)
#读取
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

Método 2: guarde el modelo completo

#保存
torch.save(the_model, PATH)
#读取
the_model = torch.load(PATH)

Debe tenerse en cuenta que el nombre de archivo recomendado para el archivo guardado es el formato .tar × ×

 

También puede consultar este artículo:

https://blog.csdn.net/weixin_41680653/article/details/93768559 

Supongo que te gusta

Origin blog.csdn.net/weixin_43450646/article/details/106931575
Recomendado
Clasificación