PyTorch gradient clipping with torch.nn.utils.clip_grad_norm_

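As a deep network gains layers, the chain rule multiplies more gradient terms together during backpropagation, which makes vanishing and exploding gradients more likely. For exploding gradients, besides batch normalization (BN), shortcut connections, a different activation function, and weight regularization, another remedy is gradient clipping, i.e. setting an upper bound on the gradient magnitude.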

torch.nn.utils.clip_grad_norm_

Usage

Call it after the loss has been backpropagated (loss.backward()) and before the parameters are updated (optimizer.step()).
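The call order can be sketched as a single training step. The model, data, learning rate, and max_norm=1.0 here are hypothetical placeholders chosen only for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and data, used only to illustrate the call order.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(inputs), targets)
loss.backward()  # 1. populate .grad on every parameter

# 2. rescale the gradients in place so their combined 2-norm is at most max_norm
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()  # 3. update the parameters using the clipped gradients
```

Clipping before loss.backward() would find no gradients to rescale, and clipping after optimizer.step() would be too late to affect that update.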

Function definition

def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """

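parameters: the model parameters whose gradients are clipped (an iterable of tensors or a single tensor)
max_norm: the maximum allowed norm of the gradients of those parameters, taken together
norm_type: the type of norm; defaults to 2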
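As a small illustration of these arguments and of the return value, the sketch below sets gradients by hand on two hypothetical tensors (w and b are made-up names, not part of any real model) so the numbers are easy to follow:

```python
import torch

# Two hypothetical "parameters" with hand-set gradients.
w = torch.zeros(2, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w.grad = torch.tensor([3.0, 0.0])
b.grad = torch.tensor([4.0])

# norm_type defaults to 2, so the total norm is sqrt(3^2 + 0^2 + 4^2) = 5.
total_norm = torch.nn.utils.clip_grad_norm_([w, b], max_norm=1.0)

# The return value is the total norm measured before clipping.
norm_before = total_norm.item()

# After the call, the combined gradient norm is approximately max_norm.
norm_after = torch.cat([w.grad, b.grad]).norm(2).item()
```

Note that max_norm bounds the norm of all gradients taken together, not the norm of each tensor separately.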

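Computing total_norm

If norm_type is inf, total_norm is the maximum of the per-parameter infinity norms, i.e. the largest absolute value of any gradient element. In the excerpts that follow, parameters has already been filtered to tensors whose .grad is not None, and device is the device of the first gradient: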

norms = [p.grad.detach().abs().max().to(device) for p in parameters]
total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))

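Otherwise, the p-norm of each parameter's gradient is computed, these per-parameter norms are stacked into a new vector, and the p-norm of that vector is taken as total_norm. The result equals the p-norm of all gradients concatenated into a single vector: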

total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
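The two branches can be checked on a small example. The helper total_norm below is a hypothetical re-statement of the excerpts for illustration, not the library function itself, and the gradient values are made up:

```python
import torch

# Hypothetical gradients for two parameters of different shapes.
grads = [torch.tensor([[1.0, -2.0], [2.0, 0.0]]), torch.tensor([-4.0])]

def total_norm(grads, norm_type):
    if norm_type == float("inf"):
        # Largest absolute gradient element across all tensors.
        return torch.stack([g.abs().max() for g in grads]).max()
    # p-norm of the vector of per-tensor p-norms.
    per_tensor = torch.stack([torch.norm(g, norm_type) for g in grads])
    return torch.norm(per_tensor, norm_type)

# All gradients flattened into one vector, for comparison.
flat = torch.cat([g.flatten() for g in grads])

l2 = total_norm(grads, 2.0)             # sqrt(1 + 4 + 4 + 0 + 16) = 5
linf = total_norm(grads, float("inf"))  # max(|1|, |-2|, |2|, |0|, |-4|) = 4
```

Here the per-tensor 2-norms are 3 and 4, and the 2-norm of [3, 4] is 5, which matches flat.norm(2).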

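Computing clip_coef and clip_coef_clamped

max_norm is compared with total_norm through their ratio; the 1e-6 in the denominator guards against division by zero, and the ratio is then clamped to at most 1: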

clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

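If the ratio is less than 1, i.e. max_norm < total_norm, the parameter gradients are scaled down by clip_coef_clamped. If max_norm > total_norm, i.e. the preset upper bound is not exceeded, the coefficient is clamped to 1, so the multiplication leaves the gradients unchanged. The multiplication is applied to every parameter in both cases: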

for p in parameters:
    p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))
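The arithmetic of the scaling step can be traced in plain Python. This is a minimal sketch on a flat list of numbers with the 2-norm only; clip_by_total_norm is a hypothetical helper written for this example, not the PyTorch implementation:

```python
import math

def clip_by_total_norm(grads, max_norm, eps=1e-6):
    """Return (scaled_grads, total_norm) for a flat list of gradient values."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    clip_coef_clamped = min(clip_coef, 1.0)  # never scale gradients up
    return [g * clip_coef_clamped for g in grads], total_norm

# Case 1: total norm 5 exceeds max_norm 1, so gradients shrink by about 1/5.
clipped, norm = clip_by_total_norm([3.0, 4.0], max_norm=1.0)

# Case 2: total norm 5 is below max_norm 10, so the coefficient clamps to 1.
unchanged, _ = clip_by_total_norm([3.0, 4.0], max_norm=10.0)
```

Because every element is multiplied by the same coefficient, clipping changes the length of the gradient vector but not its direction.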

Reposted from blog.csdn.net/qq_38964360/article/details/131423064