Pytorch 钩子函数hook的使用

1. hook函数

为了节省显存(内存),PyTorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。hook函数在使用后应及时删除(remove),以避免每次都运行钩子增加运行负载。

这里总结一下并给出实际用法和注意点。列举了常见的4种hook函数:

  • 1、Tensor.register_hook() # 用来导出指定张量的梯度,或修改这个梯度值
  • 2、torch.nn.Module.register_forward_hook()
  • 3、torch.nn.Module.register_backward_hook()
  • 4、torch.nn.Module.register_forward_pre_hook()

2. hook函数说明

2.1 Tensor.register_hook()

用来导出指定张量的梯度,或修改这个梯度值。
在这里插入图片描述
注意:

    1. 上述代码是有效的,但如果写成 grad = grad * 2就失效了,因为此时没有对grad进行本地操作,新的grad 值没有传递给指定的梯度。保险起见,最好在def语句中写明re

猜你喜欢

转载自blog.csdn.net/weixin_38346042/article/details/129637343
今日推荐