How do you hand-write the softmax function so that it avoids numerical overflow?

When I write cross-entropy by hand, errors sometimes occur even though the math itself is correct. The real problem is overflow and underflow: when the logits are extremely large or extremely small, computing softmax directly from the formula (exponentiating the raw logits) overflows. The first step in fixing this is to subtract the row-wise maximum from the logits:

logits = logits - torch.max(logits, 1)[0][:, None]

The principle behind this is explained in this post:
https://zhuanlan.zhihu.com/p/29376573
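For reference, here is a minimal sketch of a hand-written softmax along these lines; the function name stable_softmax and the test values are just for illustration, not part of the original post.

import torch

def stable_softmax(logits, dim=-1):
    # Subtracting the per-row maximum does not change the result mathematically,
    # but it keeps exp() from overflowing on very large logits.
    shifted = logits - logits.max(dim=dim, keepdim=True)[0]
    exp = torch.exp(shifted)
    return exp / exp.sum(dim=dim, keepdim=True)

logits = torch.tensor([[1000.0, 2000.0, 3000.0]])
print(stable_softmax(logits))         # tensor([[0., 0., 1.]]) -- no overflow, no nan
print(torch.softmax(logits, dim=-1))  # matches the built-in result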

But even after subtracting the maximum, trouble can still show up: checking the output reveals that some softmax probabilities underflow to exactly 0, and passing 0 through the log function gives negative infinity. So at this point, do not write the log by hand like this:

torch.log(F.softmax(logits, dim=-1))

Instead, use the log_softmax that comes with torch, which is implemented in a numerically stable way: it computes the log and the softmax together rather than taking the log of an already-underflowed probability:

F.log_softmax(logits, dim=-1)
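A toy comparison of the two approaches; the extreme logit value -1000 is just chosen here to force underflow.

import torch
import torch.nn.functional as F

# One very negative logit: its softmax probability underflows to exactly 0.
logits = torch.tensor([[0.0, -1000.0]])

manual = torch.log(F.softmax(logits, dim=-1))  # log(0) -> -inf
builtin = F.log_softmax(logits, dim=-1)        # stays finite

print(manual)   # tensor([[0., -inf]])
print(builtin)  # tensor([[0., -1000.]])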

Alternatively, if you do take the log yourself, add a very small constant inside it so the argument can never be exactly 0.
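A sketch of that workaround; the epsilon value 1e-12 is an arbitrary choice, and note that it distorts very small probabilities, so log_softmax remains the cleaner fix.

import torch
import torch.nn.functional as F

logits = torch.tensor([[0.0, -1000.0]])
eps = 1e-12  # small constant; the exact value is a tuning choice

# The epsilon keeps the argument of log strictly positive, so no -inf appears,
# but the second entry comes out as log(1e-12) instead of the true -1000.
log_probs = torch.log(F.softmax(logits, dim=-1) + eps)
print(log_probs)  # roughly tensor([[0., -27.63]])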

Origin: https://blog.csdn.net/weixin_42988382/article/details/123284103