Earlystopping, torch.save guarda el archivo del modelo en la carpeta especificada y lo nombra dinámicamente
Descripción del problema
Durante el entrenamiento de una red neuronal puede producirse sobreajuste. El método de detención temprana (early stopping) lo evita de forma eficaz: el entrenamiento se termina anticipadamente cuando la pérdida sobre el conjunto de validación lleva varias épocas consecutivas sin disminuir (o incluso aumentando). Durante este proceso también es necesario guardar el archivo del mejor modelo obtenido hasta el momento.
Solución
class EarlyStopping:
    """Stop training early once the validation loss stops improving.

    Call the instance once per validation pass with the current validation
    loss and the model.  Whenever the loss improves by more than ``delta``
    the model's ``state_dict`` is written to
    ``<save_dir>/checkpoint_model_layer<layer>.pt``; after ``patience``
    consecutive calls without such an improvement ``early_stop`` is set to
    ``True``.
    """

    def __init__(self, patience=7, verbose=False, delta=0, layer=1,
                 save_dir='./ResultData_earlystop/savemodel/'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            layer: Tag inserted into the checkpoint file name
                            (``checkpoint_model_layer<layer>.pt``).
                            Default: 1
            save_dir (str): Directory the checkpoint is written to; it is
                            created on demand.
                            Default: './ResultData_earlystop/savemodel/'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` exists in every release.
        self.val_loss_min = np.inf
        self.delta = delta
        self.layer = layer
        self.save_dir = save_dir

    def __call__(self, val_loss, model, layer=None):
        """Record one validation result and update the early-stopping state.

        Args:
            val_loss: Validation loss of the current epoch (lower is better).
            model: Object exposing ``state_dict()``; checkpointed on improvement.
            layer: Accepted for backward compatibility and forwarded to
                ``save_checkpoint``, which names the file after the ``layer``
                given to the constructor.
        """
        # Work with the negated loss so that "larger score" means "better".
        score = -val_loss
        if self.best_score is None:
            # First call: nothing to compare against yet, always checkpoint.
            self.best_score = score
            self.save_checkpoint(val_loss, model, layer)
        elif score < self.best_score + self.delta:
            # Not better than the best score by at least ``delta``.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, layer)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, layer=None):
        """Save the model's ``state_dict`` when the validation loss decreases.

        Args:
            val_loss: The new best validation loss; stored in ``val_loss_min``.
            model: Object exposing ``state_dict()``.
            layer: Ignored; the file name always uses ``self.layer``.  Kept so
                existing three-argument calls continue to work.
        """
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        # torch.save does not create missing parent directories, so make sure
        # the target folder exists before writing.
        os.makedirs(self.save_dir, exist_ok=True)
        filepath = os.path.join(self.save_dir, 'checkpoint_model_layer{}.pt'.format(self.layer))
        # Stores the parameters of the best model seen so far.
        torch.save(model.state_dict(), filepath)
        self.val_loss_min = val_loss
# Usage excerpt (fragments of a training script; `patience`, `EvalLoss` and
# `model` are defined elsewhere in that script).
# Create the early-stopping helper; presumably done once before the training
# loop starts.
early_stopping = EarlyStopping(patience=patience, verbose=True,layer=1)
# early_stopping needs the validation loss to check if it has decreased,
# and if it has, it will make a checkpoint of the current model.
# NOTE(review): the `break` below implies this call sits inside the training
# loop, after each validation pass; confirm against the full script.
early_stopping(EvalLoss,model,1)
if early_stopping.early_stop:
print("Early stopping")
break
# load the last checkpoint with the best model
# The path matches the file EarlyStopping.save_checkpoint writes for layer=1.
model.load_state_dict(torch.load('./ResultData_earlystop/savemodel/checkpoint_model_layer1.pt'))