PyTorch Gradient Clipping

Gradient clipping rescales gradients whose norm exceeds a threshold, which helps prevent exploding gradients during training:

import torch.nn as nn

outputs = model(data)
loss = loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# Clip after backward() (gradients exist) and before step() (gradients are applied)
nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

nn.utils.clip_grad_norm_ parameters:

    • parameters - an iterable of Tensors (or a single Tensor) whose gradients will be normalized
    • max_norm - maximum allowed norm of the gradients
    • norm_type - type of the p-norm to use; defaults to 2 (the L2 norm)
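The function also returns the total norm of the gradients computed before clipping, which is handy for monitoring training. Below is a minimal self-contained sketch; the linear model, loss, and random data are placeholders for illustration only:

import torch
import torch.nn as nn

# Placeholder model, loss, and data for illustration only
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data, target = torch.randn(32, 10), torch.randn(32, 1)

outputs = model(data)
loss = loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# Returns the total (pre-clipping) norm of all gradients viewed as a single vector
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
print(f"gradient norm before clipping: {total_norm.item():.4f}")
optimizer.step()

An alternative is nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0), which clamps each gradient element to the range [-clip_value, clip_value] instead of rescaling by the overall norm.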



Origin: www.cnblogs.com/Bella2017/p/11931131.html