Focal Loss原理解读以及Pytorch实现

论文:Focal Loss for Dense Object Detection
论文链接:https://arxiv.org/abs/1708.02002
Focal Loss这篇文章是He Kaiming和RBG发表在ICCV2017上的文章,详细的论文地址在上面,这里记录下我学习过程,下面我参考的几篇博客
Focal Loss
Focal Loss论文阅读 - Focal Loss for Dense Object Detection
如何评价Kaiming的Focal Loss for Dense Object Detection?

为什么要使用Focal Loss

Focal Loss是为了解决目标检测领域的一些问题而被提出来的,其 主要是为了解决样本类别不均衡问题(也有人说实际上也是解决了 gradient 被 easy example dominant 的问题)。我们知道object detection的算法主要可以分为两大类:two-stage detector和one-stage detector。前者是指类似Faster RCNN,RFCN这样需要region proposal的检测算法,这类算法可以达到很高的准确率,但是速度较慢。虽然可以通过减少proposal的数量或降低输入图像的分辨率等方式达到提速,但是速度并没有质的提升。后者是指类似YOLO,SSD这样不需要region proposal,直接回归的检测算法,这类算法速度很快,但是准确率不如前者。作者提出focal loss的出发点也是希望one-stage detector可以达到two-stage detector的准确率,同时不影响原有的速度

既然有了出发点,那么就要找one-stage detector的准确率不如two-stage detector的原因,作者认为原因是:样本的类别不均衡导致的。我们知道在object detection领域,一张图像可能生成成千上万的candidate locations,但是其中只有很少一部分是包含object的,这就带来了类别不均衡。那么类别不均衡会带来什么后果呢?引用原文讲的两个后果:(1) training is inefficient as most locations are easy negatives that contribute no useful learning signal; (2) en masse, the easy negatives can overwhelm training and lead to degenerate models. 什么意思呢?负样本数量太大,占总的loss的大部分,而且多是容易分类的,因此使得模型的优化方向并不是我们所希望的那样。以我YOLO举例子,比如在PASCAL VOC数据集中,每张图片上标注的目标可能也就几个。但是YOLO V2最后一层的输出是13×13×5,也就是845个候选目标!大量(简单易区分)的负样本在loss中占据了很大比重,使得有用的loss不能回传回来

因此针对类别不均衡问题,作者提出一种新的损失函数:focal loss,这个损失函数是在标准交叉熵损失基础上修改得到的。这个函数可以通过减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。为了证明focal loss的有效性,作者设计了一个dense detector:RetinaNet,并且在训练时采用focal loss训练。实验证明RetinaNet不仅可以达到one-stage detector的速度,也能有two-stage detector的准确率。

原理

Focal Loss从交叉熵损失而来。二分类的交叉熵损失如下:
C E ( p , y ) = { − l o g ( p ) , if y=1 − l o g ( 1 − p ) , otherwise (1) CE(p,y) = \begin{cases} −log(p), & \text{if y=1} \\ −log(1−p), & \text{otherwise} \end{cases} \tag{1} CE(p,y)={ log(p),log(1p),if y=1otherwise(1)
对应的,多分类的交叉熵损失是这样的:
C E ( p , y ) = − l o g ( p y ) (2) CE(p,y)=−log(py)\tag{2} CE(p,y)=log(py)(2)
因为是二分类,所以y的值是正1或负1,p的范围为0到1。当真实label是1,也就是y=1时,假如某个样本x预测为1这个类的概率p=0.6,那么损失就是 − l o g ( 0.6 ) -log(0.6) log(0.6),注意这个损失是大于等于0的。如果p=0.9,那么损失就是 − l o g ( 0.9 ) -log(0.9) log(0.9),所以p=0.6的损失要大于p=0.9的损失,这很容易理解。

如下图所示,比如说蓝色线为交叉熵损失函数随着 p t p_t pt变化的曲线( p t p_t pt意为ground truth,是标注类别所对应的概率)。可以看到,当概率大于0.5,即认为是易分类的简单样本时,值仍然较大。这样,很多简单样本累加起来,就很可能盖住那些稀少的不易正确分类的类别。
在这里插入图片描述
为了方便,用 p t p_t pt代替 p p p,如下公式2:。这里的pt就是上面图中中的横坐标,其实讲白了就是把二分类形式的交叉熵改成多分类的形式,原本的 p p p一般都是由sigmoid函数得来,一般都表示正样本的概率,而 p t p_t pt意标注类别所对应的概率。
p t = { p , if y=1 1 − p , otherwise (3) p_t = \begin{cases} p, & \text{if y=1} \\ 1−p, & \text{otherwise} \end{cases} \tag{3} pt={ p,1p,if y=1otherwise(3)

并且 C E ( p , y ) = C E ( p t ) = − l o g ( p t ) CE(p,y)=CE(p_t)=−log(p_t) CE(p,y)=CE(pt)=log(pt)
接下来介绍一个最基本的对交叉熵的改进,也将作为本文实验的baseline,如下公式4。什么意思呢?增加了一个系数 α t \alpha_t αt,跟 p t p_t pt的定义类似,表示不同类别的权重。当label=1的时候, α t = α \alpha_t=\alpha αt=α;当label=-1的时候, α t = 1 − α \alpha_t=1-\alpha αt=1α α \alpha α的范围也是0到1。因此可以通过设定 α \alpha α的值(一般而言假如1这个类的样本数比-1这个类的样本数多很多,那么 α \alpha α会取0到0.5来增加-1这个类的样本的权重)来控制正负样本对总的loss的共享权重。
C E ( p , y ) = − α t l o g ( p t ) (4) CE(p,y)=−\alpha_t log(p_t)\tag{4} CE(p,y)=αtlog(pt)(4)
显然前面的公式4虽然可以控制正负样本的权重,但是没法控制容易分类和难分类样本的权重,于是就有了focal loss:
F L ( p , y ) = − ( 1 − p t ) γ l o g ( p t ) (5) FL(p,y)=−(1-p_t)^\gamma log(p_t)\tag{5} FL(p,y)=(1pt)γlog(pt)(5)
这里的 γ \gamma γ称作focusing parameter, γ \gamma γ>=0, ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ称为调制系数(modulating factor)
这里介绍下focal loss的两个重要性质:1、当一个样本被分错的时候,pt是很小的(请结合公式3,比如当y=1时,p要小于0.5才是错分类,此时pt就比较小,反之亦然),因此调制系数就趋于1,也就是说相比原来的loss是没有什么大的改变的。当pt趋于1的时候(此时分类正确而且是易分类样本),调制系数趋于0,也就是对于总的loss的贡献很小。2、当γ=0的时候,focal loss就是传统的交叉熵损失,当γ增加的时候,调制系数也会增加。
focal loss的两个性质算是核心,其实就是用一个合适的函数去度量难分类和易分类样本对总的损失的贡献。
作者在实验中采用的是公式6的focal loss(结合了公式4和公式5,这样既能调整正负样本的权重,又能控制难易分类样本的权重):
F L ( p , y ) = − α t ( 1 − p t ) γ l o g ( p t ) (6) FL(p,y)=−\alpha_t(1-p_t)^\gamma log(p_t)\tag{6} FL(p,y)=αt(1pt)γlog(pt)(6)
在实验中a的选择范围也很广,一般而言当γ增加的时候,a需要减小一点(实验中γ=2,a=0.25的效果最好)

pytorch复现

下面贴出一个pytorch复现,需要说明的是这个复现是针对多分类问题的,当然二分类问题也是多分类问题中的一种,只不过如果想要运用在二分类问题上,最后的激活函数就不能使用sigmoid函数,只能使用softmax函数。

from torch import nn
import torch
from torch.nn import functional as F


class focal_loss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, num_classes=3, size_average=True):
        """
        focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)
        步骤详细的实现了 focal_loss损失函数.
        :param alpha:   阿尔法α,类别权重.      当α是列表时,为各类别权重,当α为常数时,类别权重为[α, 1-α, 1-α, ....],常用于 目标检测算法中抑制背景类 , retainnet中设置为0.25
        :param gamma:   伽马γ,难易样本调节参数. retainnet中设置为2
        :param num_classes:     类别数量
        :param size_average:    损失计算方式,默认取均值
        """
        super(focal_loss,self).__init__()
        self.size_average = size_average
        if isinstance(alpha,list):
            assert len(alpha)==num_classes   # α可以以list方式输入,size:[num_classes] 用于对不同类别精细地赋予权重
            print("Focal_loss alpha = {},".format(alpha))
            self.alpha = torch.Tensor(alpha)
        else:
            assert alpha<1   #如果α为一个常数,则降低第一类的影响,在目标检测中为第一类
            print(" --- Focal_loss alpha = {}".format(alpha))
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1-alpha) # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]

        self.gamma = gamma

    def forward(self, preds_softmax, labels):
        """
        focal_loss损失计算
        :param preds:   预测类别. size:[B,N,C] or [B,C]    分别对应与检测与分类任务, B 批次, N检测框数, C类别数
        :param labels:  实际类别. size:[B,N] or [B]
        :return:
        """
        self.alpha = self.alpha.to(preds_softmax.device)
        preds_softmax = preds_softmax.view(-1,preds_softmax.size(-1))
        preds_logsoft = torch.log(preds_softmax)

        preds_softmax = preds_softmax.gather(1,labels.view(-1,1))   # 这部分实现nll_loss ( crossempty = log_softmax + nll )
        preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
        alpha = self.alpha.gather(0,labels.view(-1))
        loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft)  # torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ

        loss = torch.mul(alpha, loss.t())
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss

猜你喜欢

转载自blog.csdn.net/weixin_41693877/article/details/106685350
今日推荐