1. hook函数
为了节省显存(内存)
,PyTorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。hook函数在使用后应及时删除
(remove),以避免每次都运行钩子增加运行负载。
这里总结一下并给出实际用法和注意点。列举了常见的4种hook函数:
- 1、
Tensor.register_hook()
# 用来导出指定张量的梯度,或修改这个梯度值 - 2、
torch.nn.Module.register_forward_hook()
- 3、
torch.nn.Module.register_backward_hook()
- 4、
torch.nn.Module.register_forward_pre_hook()
2. hook函数说明
2.1 Tensor.register_hook()
用来导出指定张量的梯度,或修改这个梯度值。
注意:
-
- 上述代码是有效的,但如果写成 grad = grad * 2就失效了,因为此时没有对grad进行本地操作,新的grad 值没有传递给指定的梯度。保险起见,
最好在def
语句中写明re
- 上述代码是有效的,但如果写成 grad = grad * 2就失效了,因为此时没有对grad进行本地操作,新的grad 值没有传递给指定的梯度。保险起见,