pytorch carga un archivo de modelo ya guardado y lo usa como un peso preentrenado para otra red
Descripción del problema
(Para tomar notas usted mismo) Guardé el modelo entrenado antes y ahora quiero cargar el modelo.
Solución
#加载保存好的模型
Layer1pre = torch.load('./ResultData_earlystop/savemodel/checkpoint_model_layer1.pt')
#定义自己的模型
model = CNNLayer(num_classes=10, aux_logits=True)
if use_gpu:
model = model.cuda()
#将模型权重更新到新的网络中
model.load_state_dict(Layer1pre, strict=False)
Enlace de referencia: https://blog.csdn.net/my_kingdom/article/details/85218478