开激活重计算
GPU利用率一般在 0.3 - 0.55 之间,假定为0.45
4090 理论性能:FP16:82.58 TFLOPS
不开激活重计算
我们来说一下系数8或6是怎么来的:
- 对于每个模型参数,都进行2次浮点数计算,即计算Y = AB 时,先将元素按位相乘,再按位相加,因此每个参数都需要进行两次浮点数运算。
- 反向传播的计算量是前向传播时的两倍
个人理解,对每个参数而言,反向传播时需要计算:一阶导数、二阶导数、梯度累积、参数更新。总共四次运算,前向传播只需要计算两次,因此计算量是前向计算的两倍 - 开激活重计算时,反向传播时需要额外进行一次前向传播
因此
前向传递 + 后向传递 + 激活重计算的系数 = 1 + 2 + 1 = 4
使用激活重计算的一次训练迭代中,对于每个token,每个模型参数,需要进行
2 x 4 = 8 次浮点数运算
想更具体了解反向传播的计算量是前向计算的两倍的同学的可以查阅以下文章:
1、What’s the backward-forward FLOP ratio for Neural Networks?
2、How the backpropagation algorithm works
注意: 内存比较小时再开激活重计算,若内存充足则没必要开激活重计算了