pytorch使用——(十四)Hook函数

1、Hook Function

Hook函数机制:不改变主体,实现额外功能,像一个挂件,挂钩,hook

  • torch.Tensor.register_hook(hook)
  • torch.nn.Module.register_forward_hook
  • torch.nn.Module.register_forward_pre_hook
  • torch.nn.Module.register_backward_hook

2、Tensor.register_hook,功能:注册一个反向传播hook函数

Hook函数仅一个输入参数,为张量的梯度。

hook(grad) -> Tensor or None

3、 Module.register_forward_hook,功能:注册module前向传播hook函数

       hook(module, input, output) -> None

  • module: 当前网络层
  • input:当前网络层输入数据
  • output:当前网络层输出数据

4、Module.register_forward_pre_hook,功能:注册module前向传播前的hook函数

hook(module, input) -> None

  • module: 当前网络层
  • nput:当前网络层输入数据

5、Module.register_backward_hook,注册module反向传播的hook函数

hook(module, grad_input, grad_output) -> Tensor or None

• module: 当前网络层
• grad_input:当前网络层输入梯度数据
• grad_output:当前网络层输出梯度数据

猜你喜欢

转载自blog.csdn.net/weixin_37799689/article/details/106486219
今日推荐