Pytorch中的checkPoint: torch.utils.checkpoint.checkpoint

torch.utils.checkpoint.checkpoint笔记,内容来源于官方手册
仅作笔记只用,不完整之处请查阅官方手册
https://pytorch.org/docs/stable/checkpoint.html

checkpoint是通过在backward期间为每个checkpoint段重新运行forward-pass segment来实现的。
这可能会导致像 RNG 状态这样的持久状态比没有checkpoint的情况更先进。默认情况下,checkpoint包括处理 RNG 状态的逻辑,以便与非checkpoint传递相比,使用 RNG 的checkpoint传递(例如通过 dropout)具有确定性输出。
根据checkpoint操作的运行时间,存储和恢复 RNG 状态的逻辑可能会导致一定的性能下降。

注:RNG状态指随机数状态

如果不需要确定性输出,则为checkpoint或 checkpoint_sequential 设置preserve_rng_state=False ,以在每个checkpoint期间省略存储和恢复RNG 状态。

换言之,1.9版本中,checkpoint能处理随机数状态了!

存储逻辑将当前设备和所有 cuda Tensor 参数的设备的 RNG 状态保存并恢复到 run_fn。但是,逻辑无法预测用户是否会将张量移动到 run_fn 本身内的新设备。因此,如果您在 run_fn 中将张量移动到一个新设备确定性输出永远无法保证。

尽量使用原生的 torch.utils.checkpoint.checkpoint

checkpoint的工作原理是用计算换取内存。与存储整个计算图的所有中间激活用于反向计算不同,checkpoint部分不保存中间激活,而是在反向传递中重新计算它们.它可以应用于模型的任何部分。

具体来说,在前向传递中,函数将以 torch.no_grad() 方式运行,即不存储中间激活。相反,前向传递保存输入元组和函数参数。在向后传递中,检索保存的输入和函数,并再次对函数计算前向传递,现在跟踪中间激活,然后使用这些激活值计算梯度。

函数的输出可以包含非 Tensor 值,并且仅对 Tensor 值执行梯度记录。请注意,如果输出包含由张量组成的嵌套结构(例如:自定义对象、列表、字典等),则这些嵌套在自定义结构中的张量将不会被视为 autograd 的一部分。

猜你喜欢

转载自blog.csdn.net/ftimes/article/details/120678872