Early stopping: using torch.save to save the model file to a specified folder with a dynamically generated name
Problem Description
Overfitting can occur during neural network training. With early stopping, training is terminated once the validation loss has failed to improve for a given number of consecutive epochs, which effectively avoids overfitting. During this process, the model file of the best epoch so far is saved to disk at the same time.
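To make the stopping rule concrete, here is a minimal self-contained sketch of the patience counter. The loss values and the patience of 3 are made up for illustration; the solution below uses a default patience of 7.

# Toy illustration of the patience rule (made-up numbers, not part of the solution).
val_losses = [0.90, 0.75, 0.70, 0.71, 0.72, 0.73, 0.74]
patience = 3
best, counter = float('inf'), 0
for epoch, loss in enumerate(val_losses, start=1):
    if loss < best:
        best, counter = loss, 0   # improvement: remember it and reset the counter
    else:
        counter += 1              # no improvement this epoch
    if counter >= patience:
        print(f'Stopping at epoch {epoch}; best validation loss was {best:.2f}')
        break
# -> Stopping at epoch 6; best validation loss was 0.70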
Solution
import os

import numpy as np
import torch


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, layer=1):
        """
        Args:
            patience (int): How long to wait after the last time the validation loss improved.
                Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                Default: 0
            layer (int): Index used to name the checkpoint file dynamically.
                Default: 1
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.layer = layer

    def __call__(self, val_loss, model, layer=None):
        if layer is None:
            layer = self.layer
        score = -val_loss
        if self.best_score is None:
            # First call: record the score and save an initial checkpoint.
            self.best_score = score
            self.save_checkpoint(val_loss, model, layer)
        elif score < self.best_score + self.delta:
            # No sufficient improvement: count towards early stopping.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: save a checkpoint and reset the counter.
            self.best_score = score
            self.save_checkpoint(val_loss, model, layer)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, layer):
        '''Saves the model when the validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        save_path = './ResultData_earlystop/savemodel/'
        os.makedirs(save_path, exist_ok=True)  # make sure the target folder exists
        filepath = os.path.join(save_path, 'checkpoint_model_layer{}.pt'.format(layer))
        torch.save(model.state_dict(), filepath)  # stores the parameters of the best model so far
        self.val_loss_min = val_loss
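Because the layer index is interpolated into the checkpoint filename, each EarlyStopping instance writes its own file. For example, a hypothetical second instance (not in the original code):

early_stopping2 = EarlyStopping(patience=7, verbose=True, layer=2)
# its checkpoints go to ./ResultData_earlystop/savemodel/checkpoint_model_layer2.pt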
# Create once, before the training loop starts:
early_stopping = EarlyStopping(patience=patience, verbose=True, layer=1)

# Then, at the end of each epoch:
# early_stopping needs the validation loss to check if it has decreased,
# and if it has, it will make a checkpoint of the current model
early_stopping(EvalLoss, model, 1)
if early_stopping.early_stop:
    print("Early stopping")
    break
# load the last checkpoint with the best model
model.load_state_dict(torch.load('./ResultData_earlystop/savemodel/checkpoint_model_layer1.pt'))
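For context, here is a minimal end-to-end sketch showing where the snippet above sits in a training loop. The tiny model, optimizer, loss, and random data are hypothetical placeholders, not from the original post; only the EarlyStopping usage follows the solution.

import torch
import torch.nn as nn

# Hypothetical toy setup: a one-layer regression model on random data.
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
x_train, y_train = torch.randn(64, 10), torch.randn(64, 1)
x_val, y_val = torch.randn(16, 10), torch.randn(16, 1)

patience = 7
early_stopping = EarlyStopping(patience=patience, verbose=True, layer=1)

for epoch in range(200):
    # Training step.
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    # Validation step.
    model.eval()
    with torch.no_grad():
        EvalLoss = criterion(model(x_val), y_val).item()

    early_stopping(EvalLoss, model, 1)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# Restore the best weights recorded by save_checkpoint.
model.load_state_dict(torch.load('./ResultData_earlystop/savemodel/checkpoint_model_layer1.pt'))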