PyTorch gradient clipping with torch.nn.utils.clip_grad_norm_

As the number of layers in a deep network grows, the number of terms multiplied together by the chain rule during backpropagation grows as well, which can easily cause vanishing or exploding gradients. For exploding gradients, besides batch normalization, shortcut connections, changing the activation function, and weight regularization, another solution is gradient clipping: setting an upper limit on the magnitude of the gradient.

torch.nn.utils.clip_grad_norm_

Usage

Call it after backpropagating the loss (loss.backward()) and before updating the parameters (optimizer.step()).
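
A minimal sketch of a single training step (the model, optimizer, loss function, and data below are placeholders, not from the original post):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
inputs, targets = torch.randn(8, 10), torch.randn(8, 1)

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()                                                     # backpropagate
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)    # clip gradients
optimizer.step()                                                    # update parameters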

Function definition

def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """

parameters : an iterable of model parameters (or a single Tensor) whose gradients will be clipped
max_norm : the maximum allowed norm of the parameter gradients
norm_type : the type of p-norm to use; the default is 2 (the L2 norm)

Calculate total_norm

If norm_type is inf, total_norm is the maximum absolute gradient value across all of the input parameters:

norms = [p.grad.detach().abs().max().to(device) for p in parameters]
total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
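
A toy check of this branch (the tensors below are made up for illustration): with gradients whose largest absolute entries are 4 and 3, total_norm is 4.

import torch

p1 = torch.zeros(3, requires_grad=True)
p1.grad = torch.tensor([1.0, -4.0, 2.0])
p2 = torch.zeros(2, requires_grad=True)
p2.grad = torch.tensor([3.0, -0.5])

# max_norm is large enough that nothing is actually clipped here
total_norm = torch.nn.utils.clip_grad_norm_([p1, p2], max_norm=10.0, norm_type=float('inf'))
print(total_norm)  # tensor(4.) -- the largest absolute gradient entry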

Otherwise, the p-norm of each parameter's gradient is computed, these per-parameter norms are stacked into a new vector, and the p-norm of that vector is taken as total_norm:

total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
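
A toy check of the default L2 case (again with made-up tensors): per-parameter gradient norms of 3 and 4 combine into a total norm of 5.

import torch

p1 = torch.zeros(1, requires_grad=True)
p1.grad = torch.tensor([3.0])
p2 = torch.zeros(1, requires_grad=True)
p2.grad = torch.tensor([4.0])

total_norm = torch.nn.utils.clip_grad_norm_([p1, p2], max_norm=10.0)  # norm_type=2.0 by default
print(total_norm)  # tensor(5.) == sqrt(3**2 + 4**2)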

Calculate clip_coef and clip_coef_clamped

Next, the ratio of max_norm to total_norm is computed and clamped to at most 1:

clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

If the ratio is less than 1, i.e. max_norm < total_norm, every parameter gradient is multiplied in place by clip_coef_clamped; if max_norm > total_norm, the preset upper limit is not exceeded, clip_coef_clamped is clamped to 1, and the gradients are left effectively unchanged.

p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))
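
A short sketch of the clipping effect (illustrative tensors, not from the original post): a gradient with norm 5 is rescaled so that its norm becomes max_norm = 1.

import torch

p = torch.zeros(2, requires_grad=True)
p.grad = torch.tensor([3.0, 4.0])                 # L2 norm = 5
torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(p.grad)         # tensor([0.6000, 0.8000]) (up to the 1e-6 term in clip_coef)
print(p.grad.norm())  # approximately tensor(1.)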

Origin blog.csdn.net/qq_38964360/article/details/131423064