[Pytorch] guarda y carga el modelo de red neuronal

1. Guarde el modelo de red neuronal

        El modelo de Pytorch se guarda en el formato pth Para facilitar la operación, podemos establecer dos modos de guardado: guardar los datos del modelo de una determinada ronda y guardar los datos del modelo con la función de pérdida más pequeña en el conjunto de verificación . Al mismo tiempo, el registro de la función de pérdida también se puede guardar para su análisis. El código se implementa de la siguiente manera:

torch.save(model.state_dict(), os.path.join(保存路径, "文件名"))

        La función de pérdida se puede guardar en un archivo txt

    def append_loss(self, epoch, loss, val_loss):
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        #可以添加图像绘制代码

        El ahorro redondo y el mejor ahorro se pueden controlar agregando código lógico.

if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
    #保存代码
if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
    #保存代码

2. Carga del modelo de red neuronal

En general, hay dos tipos de modelos cargados, uno es el archivo .pth         guardado directamente por pytorch y el otro es el archivo general .onnx . Entre ellos, onnx es un estándar de modelo de red general introducido por Microsoft, que puede realizar una implementación multiplataforma. La carga de archivos onnx se describirá más adelante.

#直接加载
self.net.load_state_dict(torch.load('文件路径', map_location=device))
#分层加载--可以检测不匹配的层
model_dict      = model.state_dict()    #加载模型字典用于比对
pretrained_dict = torch.load(model_path, map_location = device)
load_key, no_load_key, temp_dict = [], [], {}
for k, v in pretrained_dict.items():
    if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
        temp_dict[k] = v
        load_key.append(k)       #匹配成功的层
    else:
        no_load_key.append(k)    #匹配失败的层
model_dict.update(temp_dict)
model.load_state_dict(model_dict)

Supongo que te gusta

Origin blog.csdn.net/weixin_37878740/article/details/128724667
Recomendado
Clasificación