简单理解Focal Loss

Focal Loss用来解决的是类别不均衡问题,其 α \alpha α变体的公式长下面这样: F L ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \mathrm{FL}\left(\mathrm{p}_{\mathrm{t}}\right)=-\alpha_{\mathrm{t}}\left(1-\mathrm{p}_{\mathrm{t}}\right)^\gamma \log \left(\mathrm{p}_{\mathrm{t}}\right) FL(pt)=αt(1pt)γlog(pt)


一点一点来看。考虑基本的二分类问题,交叉熵如下: C E ( p , y ) = { − log ⁡ ( p )  if  y = 1 − log ⁡ ( 1 − p )  if  y = 0 \mathrm{CE}(\mathrm{p}, \mathrm{y})= \begin{cases}-\log (\mathrm{p}) & \text { if } \mathrm{y}=1 \\ -\log (1-\mathrm{p}) & \text { if } \mathrm{y}=0\end{cases} CE(p,y)={ log(p)log(1p) if y=1 if y=0 其中 p p p为网络预测的结果。为了公式统一起见,记: p t = { p  if  y = 1 1 − p  if  y = 0 \mathrm{p_t}= \begin{cases}\mathrm{p} & \text { if } \mathrm{y}=1 \\ 1-\mathrm{p} & \text { if } \mathrm{y}=0\end{cases} pt={ p1p if y=1 if y=0 可以发现这一步把标签 y \mathrm{y} y p \mathrm{p} p的分类讨论情况都给统一了,现在交叉熵就可以直接写为: C E ( p , y ) = C E ( p t ) = − log ⁡ ( p t ) \mathrm{CE}(\mathrm{p}, \mathrm{y})=\mathrm{CE}\left(\mathrm{p}_{\mathrm{t}}\right)=-\log \left(\mathrm{p}_{\mathrm{t}}\right) CE(p,y)=CE(pt)=log(pt) 其中 t t t表类别,比如正类 1 1 1或者负类 0 0 0


通过上面这一部分的分析,我们可以把Focal Loss给改写为: F L ( p t ) = α t ( 1 − p t ) γ C E ( p t ) \mathrm{FL}\left(\mathrm{p}_{\mathrm{t}}\right)=\alpha_{\mathrm{t}}\left(1-\mathrm{p}_{\mathrm{t}}\right)^\gamma \mathrm{CE}\left(\mathrm{p}_{\mathrm{t}}\right) FL(pt)=αt(1pt)γCE(pt) 也就是一种改进的交叉熵,乘了两个系数, α t \alpha_{\mathrm{t}} αt ( 1 − p t ) γ \left(1-\mathrm{p}_{\mathrm{t}}\right)^\gamma (1pt)γ

首先来看 α t \alpha_{\mathrm{t}} αt。这一超参的动机非常直观,就是人工控制不同类别 t t t的权重。假设负类样本数量远多于正类,为了防止网络仅关注于负类的分类效果,我们就可以把负类的权重 α 0 \alpha_0 α0给调小一些,正类的权重 α 1 \alpha_1 α1给调大一些。这么做就初步解决了不均衡。


但是这里存在一个问题,手工设置的超参数可能不准确也不够灵活。有没有办法让网络自适应的学这个权重呢?我们进一步引入一个权重因子: ( 1 − p t ) γ \left(1-\mathrm{p}_{\mathrm{t}}\right)^\gamma (1pt)γ 可以发现,这个权重系数是与网络当前的状态 p t \mathrm{p}_{\mathrm{t}} pt有关的。写出其表达式: 1 − p t = { 1 − p  if  y = 1 p  if  y = 0 1 - \mathrm{p_t}= \begin{cases}1 - \mathrm{p} & \text { if } \mathrm{y}=1 \\ \mathrm{p} & \text { if } \mathrm{y}=0\end{cases} 1pt={ 1pp if y=1 if y=0 注意到 1 − p t 1 - \mathrm{p_t} 1pt蕴含着网络预测结果的错误率信息。该值越低,表明网络对该类的这个样本预测的更准确。基于"错题"是对于学习更有帮助的,可以得到Focal Loss的核心逻辑:

如果在训练阶段对某一类的样本总是比较准确,即 ( 1 − p t ) γ \left(1-\mathrm{p}_{\mathrm{t}}\right)^\gamma (1pt)γ更低,那么我们就将其作为系数对损失函数进行加权,使得网络不那么关注这一类的"容易样本"。反之,对于少见的困难类样本,其 ( 1 − p t ) γ \left(1-\mathrm{p}_{\mathrm{t}}\right)^\gamma (1pt)γ更高,网络对其给予更大的权重进行学习。

猜你喜欢

转载自blog.csdn.net/qq_40714949/article/details/127241724
今日推荐