Debugging NaN in PyTorch half-precision (AMP) training: a troubleshooting log

According to online tutorials, NaN problems in half-precision AMP training generally fall into these categories:

  • Division by zero when computing the loss
  • A loss too large to represent, which half precision rounds to inf
  • NaN in the network parameters, which makes every subsequent result NaN (this is more a symptom than a cause: a NaN in the weights must itself come from an earlier NaN or inf)

Summarized, there are three types:

  • Arithmetic errors, such as an x/0 when computing the loss
  • Numerical overflow: the weights and inputs are normal, but the result of an operation exceeds the representable range and becomes NaN or inf. A loss that is "too large" is exactly this case: it exceeds the representable range and turns into inf
  • Gradient problems: something may go wrong during backpropagation (I am not sure about this one)
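The first category is easy to reproduce in isolation. A minimal sketch with NumPy scalars (not taken from any real training loop) showing how division by zero produces inf and NaN:

```python
import numpy as np

# silence the RuntimeWarning NumPy emits for these operations
with np.errstate(divide='ignore', invalid='ignore'):
    a = np.float32(1.0) / np.float32(0.0)  # x/0 with x != 0 gives inf
    b = np.float32(0.0) / np.float32(0.0)  # 0/0 gives nan

print(a, b)
```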

0. Conclusion

Let me state the conclusion first. I train with AMP half precision, i.e. float16 tensors are mixed in to speed up training.

The NaN in this post was caused by float16: the largest value float16 can represent is 65504, and my model contains a matrix multiplication (the q @ k operation in a Transformer). With a ∈ [-38, 40] and b ∈ [-39, 40], the product a @ b = c came out as c ∈ [-61408, inf]. The maximum value of the matmul exceeded float16's representable range, producing inf, which then turned into NaN downstream.
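This kind of overflow can be reproduced with a couple of NumPy float16 scalars; the specific numbers below are illustrative, not taken from the real model:

```python
import numpy as np

print(np.finfo(np.float16).max)  # 65504.0, the largest float16 value

with np.errstate(over='ignore'):
    # a product of two perfectly ordinary values already exceeds 65504
    overflow = np.float16(400.0) * np.float16(400.0)  # 160000 -> inf
    # and once inf exists, later arithmetic can turn it into nan
    nan_result = overflow - overflow                  # inf - inf = nan

print(overflow, nan_result)
```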

1. Coarse positioning

A training iteration can be expressed as the following process:
(figure: forward pass → loss computation → backward pass → weight update)

1.1 Locating the epoch

First, the loss printed during epoch 4 is normal, which means iterations 0–498 of epoch 4 trained fine. The problem must therefore appear somewhere in the remaining 501 iterations: iter 499 of epoch 4, or iters 0–499 of epoch 5.

1.2 Locating the iter

Next we need to pin down the specific iter.

Binary search works well here. In debug mode, check whether the loss is still normal at iter 100, 300, and 499 of epoch 5, and keep halving the interval to locate the exact iter.
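The bisection itself can be written down generically. A sketch, where `is_loss_finite` stands in for "rerun and inspect the loss at this iter" (here simulated so the snippet runs on its own):

```python
def first_bad_iter(is_loss_finite, lo, hi):
    """Binary search for the first iter whose loss is non-finite.

    Invariant: the loss is finite at `lo` and non-finite at `hi`.
    """
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_loss_finite(mid):
            lo = mid   # still fine here; the break is later
        else:
            hi = mid   # already broken here; the break is at or before mid
    return hi

# simulated run where the loss first becomes NaN at iter 162
print(first_bad_iter(lambda it: it < 162, 0, 499))
```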

In my case it was between iters 161 and 162 of epoch 5: the loss is normal at iter = 161 and NaN at iter = 162. Each iter follows the process in the figure above, so the problem must be in one of four places: the gradient computation or weight update of iter = 161, or the forward pass or loss computation of iter = 162.

1.3 Locating the specific step

While debugging, pause right before the forward pass of epoch = 5, iter = 162.

First, check whether the weights are normal:

# Before the forward pass of iter=162, check the weights for abnormal values (NaN or inf)
import numpy as np

if epoch == 5:
    if i == 162:
        print(epoch, i)

        class bcolors:
            HEADER = '\033[95m'
            OKBLUE = '\033[94m'
            OKGREEN = '\033[92m'
            WARNING = '\033[93m'
            FAIL = '\033[91m'
            ENDC = '\033[0m'
            BOLD = '\033[1m'
            UNDERLINE = '\033[4m'

        # collect the name, values, and gradients of every parameter
        v_n = []
        v_v = []
        v_g = []
        for name, parameter in model.named_parameters():
            v_n.append(name)
            v_v.append(parameter.detach().cpu().numpy())
            v_g.append(parameter.grad.detach().cpu().numpy()
                       if parameter.grad is not None else np.zeros(1))
        for j in range(len(v_n)):
            # flag the parameter in red if its values or grads contain NaN/inf
            if not np.isfinite(v_v[j]).all() or not np.isfinite(v_g[j]).all():
                color = bcolors.FAIL + '*'
            else:
                color = bcolors.OKGREEN + ' '
            print('%svalue %s: %.3e ~ %.3e' % (color, v_n[j], np.min(v_v[j]).item(), np.max(v_v[j]).item()))
            print('%sgrad  %s: %.3e ~ %.3e' % (color, v_n[j], np.min(v_g[j]).item(), np.max(v_g[j]).item()))

outputs = model(images)

The check shows the weights are fine, so the problem narrows to the forward pass and loss computation of iter = 162.

Check the input and output with:

print(images.mean())    # check the input: normal
outputs = model(images)
print(outputs.mean())   # check the output: NaN

So now we know the situation: the model weights are normal, the model input is normal, but the model output is NaN.
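Stepping through in a debugger (as done in the next section) works, but forward hooks can automate the search. A sketch with a hypothetical toy model standing in for the real network, flagging every submodule whose output is non-finite; a weight is deliberately poisoned so the snippet demonstrates a hit:

```python
import torch
import torch.nn as nn

# toy model standing in for the real network
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

bad_modules = []  # names of modules whose output contains NaN/inf, in call order

def make_hook(name):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
            bad_modules.append(name)
    return hook

for name, module in model.named_modules():
    if name:  # skip the top-level container itself
        module.register_forward_hook(make_hook(name))

# deliberately poison a weight so the forward pass produces inf
with torch.no_grad():
    model[0].weight[0, 0] = float('inf')

_ = model(torch.ones(1, 4))
print(bad_modules)  # the first entry is the earliest offending module
```

On the real model, the first name printed is the submodule where NaN/inf is born; everything after it is just propagation.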

2. Precise positioning

From here it is straightforward. Using PyCharm's debugger, step through the submodules of the model one by one, inspecting each input and output to see where NaN or inf first appears. It finally comes down to one line of code:

attn = (q @ k.transpose(-2, -1)) * self.scale

This line implements the matrix multiplication of q and k; their value ranges are:

tensor   max (approx.)   min (approx.)
q        38              -37
k        40              -38
attn     inf             -61408

This is clearly a simple arithmetic problem, and a very common one: numerical overflow. Since I use half precision, and float16's maximum value is 65504 (easily looked up), the maximum of the product very likely overflowed. To verify, convert q and k to double (float64) before the multiplication: the result comes out normal, and its type is float64. This confirms the cause is numerical overflow.

3. Solutions

Now that the cause is known to be numerical overflow, one option is to clip: replace any inf or NaN with a constant. What I did instead was normalize q and k to [-1, 1] before the multiplication, which guarantees the result cannot grow too large (this has no principled justification; it is a brute-force fix, and I don't recommend copying it).
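Both fixes can be sketched with hypothetical q/k tensors matching the measured ranges (shapes and values are illustrative; the half-precision matmul itself is avoided below because half matmul support on CPU varies by PyTorch version):

```python
import torch

torch.manual_seed(0)
# hypothetical q/k with roughly the value ranges measured above, in float16
q = (torch.rand(4, 64) * 75 - 37).half()   # values in about [-37, 38]
k = (torch.rand(4, 64) * 78 - 38).half()   # values in about [-38, 40]

# Fix used in this post: squash q and k into [-1, 1] before multiplying.
# A single attention score is then bounded by the head dim (64) << 65504.
q_n = q / q.abs().max()
k_n = k / k.abs().max()
score = (q_n[0] * k_n[0]).sum()  # one entry of the would-be q @ k.T

# A more principled fix: upcast just this operation to float32
attn_f32 = q.float() @ k.float().transpose(-2, -1)

print(score.item(), torch.isfinite(attn_f32).all().item())
```

In a real AMP training loop, the upcast variant is usually expressed by casting the operands up inside an autocast-disabled region (`with torch.autocast(device_type='cuda', enabled=False): ...`) so only this one operation leaves float16.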


Origin blog.csdn.net/qq_40243750/article/details/128207067