1. Salve o modelo de rede neural
O modelo Pytorch é salvo no formato pth. Para conveniência da operação, podemos definir dois modos de salvamento: salvar os dados do modelo de uma determinada rodada e salvar os dados do modelo com a menor função de perda no conjunto de verificação . Ao mesmo tempo, o registro da função de perda também pode ser salvo para análise. O código é implementado da seguinte forma:
torch.save(model.state_dict(), os.path.join(保存路径, "文件名"))
A função de perda pode ser salva em um arquivo txt
def append_loss(self, epoch, loss, val_loss):
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
self.losses.append(loss)
self.val_loss.append(val_loss)
with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
f.write(str(loss))
f.write("\n")
with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
f.write(str(val_loss))
f.write("\n")
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)
#可以添加图像绘制代码
O salvamento redondo e o melhor salvamento podem ser controlados adicionando código lógico.
if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
#保存代码
if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
#保存代码
2. Carregamento do modelo de rede neural
Geralmente, existem dois tipos de modelos carregados, um é o arquivo .pth salvo diretamente por pytorch e o outro é o arquivo geral .onnx . Entre eles, o onnx é um padrão de modelo de rede geral introduzido pela Microsoft, que pode realizar a implantação entre plataformas. O carregamento de arquivos onnx será descrito posteriormente.
#直接加载
self.net.load_state_dict(torch.load('文件路径', map_location=device))
#分层加载--可以检测不匹配的层
model_dict = model.state_dict() #加载模型字典用于比对
pretrained_dict = torch.load(model_path, map_location = device)
load_key, no_load_key, temp_dict = [], [], {}
for k, v in pretrained_dict.items():
if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
temp_dict[k] = v
load_key.append(k) #匹配成功的层
else:
no_load_key.append(k) #匹配失败的层
model_dict.update(temp_dict)
model.load_state_dict(model_dict)