Gradient Clipping
import torch.nn as nn

outputs = model(data)
loss = loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# Clip gradients after backward() and before step()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()
nn.utils.clip_grad_norm_
Parameters:
- parameters - an iterable of Tensors (or a single Tensor) whose gradients will be normalized
- max_norm - the maximum allowed norm of the gradients
- norm_type - the type of the norm to compute (default: 2, i.e. the L2 norm)
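To see the clipping in action, here is a minimal self-contained sketch (the toy model, data, and max_norm value are assumptions, not from the original): clip_grad_norm_ rescales the gradients in place so that their combined norm does not exceed max_norm, and returns the total norm measured before clipping.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and data for illustration
model = nn.Linear(10, 1)
data = torch.randn(4, 10)
target = torch.randn(4, 1)

loss = nn.functional.mse_loss(model(data), target)
loss.backward()

# Deliberately small max_norm so clipping actually triggers
pre_clip_norm = nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=0.1, norm_type=2
)

# Recompute the total L2 norm over all parameter gradients after clipping
post_clip_norm = torch.norm(
    torch.stack([p.grad.norm(2) for p in model.parameters()]), 2
)
print(float(pre_clip_norm), float(post_clip_norm))
```

After the call, post_clip_norm is at most max_norm (up to floating-point error), while the returned value still reports the original, unclipped norm.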