1. Guarde el modelo de red neuronal
El modelo de Pytorch se guarda en el formato pth Para facilitar la operación, podemos establecer dos modos de guardado: guardar los datos del modelo de una determinada ronda y guardar los datos del modelo con la función de pérdida más pequeña en el conjunto de verificación . Al mismo tiempo, el registro de la función de pérdida también se puede guardar para su análisis. El código se implementa de la siguiente manera:
torch.save(model.state_dict(), os.path.join(保存路径, "文件名"))
La función de pérdida se puede guardar en un archivo txt
def append_loss(self, epoch, loss, val_loss):
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
self.losses.append(loss)
self.val_loss.append(val_loss)
with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
f.write(str(loss))
f.write("\n")
with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
f.write(str(val_loss))
f.write("\n")
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)
#可以添加图像绘制代码
El ahorro redondo y el mejor ahorro se pueden controlar agregando código lógico.
if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
#保存代码
if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
#保存代码
2. Carga del modelo de red neuronal
En general, hay dos tipos de modelos cargados, uno es el archivo .pth guardado directamente por pytorch y el otro es el archivo general .onnx . Entre ellos, onnx es un estándar de modelo de red general introducido por Microsoft, que puede realizar una implementación multiplataforma. La carga de archivos onnx se describirá más adelante.
#直接加载
self.net.load_state_dict(torch.load('文件路径', map_location=device))
#分层加载--可以检测不匹配的层
model_dict = model.state_dict() #加载模型字典用于比对
pretrained_dict = torch.load(model_path, map_location = device)
load_key, no_load_key, temp_dict = [], [], {}
for k, v in pretrained_dict.items():
if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
temp_dict[k] = v
load_key.append(k) #匹配成功的层
else:
no_load_key.append(k) #匹配失败的层
model_dict.update(temp_dict)
model.load_state_dict(model_dict)