EarlyStopping: saving the model file to a specified folder with torch.save, using a dynamically generated file name



Problem description

Overfitting can occur while a neural network is being trained. Early stopping effectively prevents it by ending training once the validation loss has stopped decreasing (or has started to increase) for a given number of epochs. The model file is saved during the same process: each time the validation loss improves, torch.save writes the current parameters to a checkpoint file.
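To make the stopping rule concrete, here is a small pure-Python sketch (no PyTorch needed) that replays the comparison used by the class below over a list of validation losses. Like the class, it uses `score = -val_loss` and treats `score < best_score + delta` as "no improvement". The function name `simulate_early_stopping` and the sample losses are illustrative assumptions, not part of the original post.

```python
from typing import List, Optional, Tuple


def simulate_early_stopping(
    val_losses: List[float], patience: int = 7, delta: float = 0.0
) -> Tuple[int, Optional[int]]:
    """Hypothetical helper: replay the patience rule over a sequence of losses.

    Returns (best_epoch, stop_epoch): the 0-based epoch whose model would be
    checkpointed last, and the epoch at which training would stop
    (None if patience is never exhausted).
    """
    best_score = None
    best_epoch = -1
    counter = 0
    for epoch, val_loss in enumerate(val_losses):
        score = -val_loss  # higher score means lower loss
        if best_score is None or score >= best_score + delta:
            # First epoch or a sufficient improvement: a checkpoint would be saved here.
            best_score = score
            best_epoch = epoch
            counter = 0
        else:
            # No sufficient improvement: use up one epoch of patience.
            counter += 1
            if counter >= patience:
                return best_epoch, epoch
    return best_epoch, None


print(simulate_early_stopping([0.9, 0.7, 0.72, 0.71, 0.73], patience=3))  # (1, 4)
```

With `patience=3`, the loss stops improving after epoch 1, so epochs 2, 3 and 4 each consume one unit of patience and training stops at epoch 4, with epoch 1 as the saved best model. A positive `delta` makes the rule stricter, because small decreases in loss no longer count as improvements.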

Solution

import os

import numpy as np
import torch


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, layer=1):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            layer (int): Index embedded in the checkpoint file name
                            (checkpoint_model_layer{layer}.pt). Default: 1
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.layer = layer

    def __call__(self, val_loss, model, layer=None):
        # Fall back to the layer index given at construction time
        if layer is None:
            layer = self.layer

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, layer)
        elif score < self.best_score + self.delta:
            # No sufficient improvement: use up one epoch of patience
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, layer)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, layer):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        save_path = './ResultData_earlystop/savemodel/'
        os.makedirs(save_path, exist_ok=True)  # torch.save fails if the folder does not exist
        # The file name is built dynamically from the layer index
        filepath = os.path.join(save_path, 'checkpoint_model_layer{}.pt'.format(layer))
        torch.save(model.state_dict(), filepath)  # stores the parameters of the best model so far
        self.val_loss_min = val_loss
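The only "dynamic" part of the file name is the layer index formatted into `checkpoint_model_layer{}.pt`. The sketch below isolates that naming step in a hypothetical helper, `checkpoint_path`, which also creates the target folder, because `torch.save` fails when the parent directory is missing. It uses only the standard library so it can be checked without PyTorch; a temporary directory stands in for `./ResultData_earlystop/savemodel/`.

```python
import os
import tempfile


def checkpoint_path(save_dir: str, layer: int) -> str:
    """Hypothetical helper: build the per-layer checkpoint path used by save_checkpoint.

    Also creates save_dir (and any missing parents) so that a later
    torch.save(...) call has somewhere to write.
    """
    os.makedirs(save_dir, exist_ok=True)
    return os.path.join(save_dir, "checkpoint_model_layer{}.pt".format(layer))


# Usage: a temporary folder stands in for ./ResultData_earlystop/savemodel/
with tempfile.TemporaryDirectory() as root:
    save_dir = os.path.join(root, "ResultData_earlystop", "savemodel")
    for layer in (1, 2, 3):
        # one line per layer: checkpoint_model_layer1.pt, ..._layer2.pt, ..._layer3.pt
        print(os.path.basename(checkpoint_path(save_dir, layer)))
```

Because each layer index maps to its own file, separate models (for example, one per layer) can be trained with separate `EarlyStopping` instances without overwriting each other's checkpoints.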

Usage in the training loop:

early_stopping = EarlyStopping(patience=patience, verbose=True, layer=1)

for epoch in range(n_epochs):  # n_epochs: hypothetical, not defined in the original snippet
    # ... train for one epoch and compute the validation loss EvalLoss ...

    # early_stopping needs the validation loss to check if it has decreased,
    # and if it has, it will make a checkpoint of the current model
    early_stopping(EvalLoss, model, 1)

    if early_stopping.early_stop:
        print("Early stopping")
        break

# after the loop: load the last checkpoint, which holds the best model
model.load_state_dict(torch.load('./ResultData_earlystop/savemodel/checkpoint_model_layer1.pt'))
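To check that loading the checkpoint after the loop restores the best epoch rather than the last one, the loop can be replayed with a stand-in model. In this sketch the "model state" is a plain dict, and `pickle.dump`/`pickle.load` take the place of `torch.save`/`torch.load`. This is an analogy for testing the control flow, not the PyTorch API. The class `TinyEarlyStopping`, the fake losses and the parameter values are all hypothetical.

```python
import os
import pickle
import tempfile


class TinyEarlyStopping:
    """Hypothetical stand-in: same patience rule as EarlyStopping, but pickles a dict."""

    def __init__(self, path, patience=2, delta=0.0):
        self.path = path
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, state):
        score = -val_loss
        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            self.counter = 0
            with open(self.path, "wb") as f:  # stand-in for torch.save
                pickle.dump(state, f)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


def run_demo(val_losses, patience=2):
    """Run a fake training loop; return (last_state, restored_state)."""
    with tempfile.TemporaryDirectory() as root:
        stopper = TinyEarlyStopping(os.path.join(root, "checkpoint_model_layer1.pt"), patience)
        state = {}
        for epoch, loss in enumerate(val_losses):
            state = {"epoch": epoch, "weight": 0.1 * epoch}  # pretend training changed the weights
            stopper(loss, state)
            if stopper.early_stop:
                break
        with open(stopper.path, "rb") as f:  # stand-in for torch.load
            restored = pickle.load(f)
    return state, restored


last, best = run_demo([0.9, 0.6, 0.65, 0.7, 0.5], patience=2)
print(last["epoch"], best["epoch"])  # 3 1
```

Training stops at epoch 3 after two epochs without improvement, so the final loss of 0.5 is never evaluated. The restored state comes from epoch 1, the last epoch whose validation loss improved.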


Source: blog.csdn.net/qq_38703529/article/details/122203101