【目标检测——OHEM 解读】处理类别不平衡问题

目录

前言

在接触一个新技术之前,肯定是因为遇到了新的难题,但这可以促使我们前进。

hard Negative Mining介绍

OHEM 概要解读

mmdetection中的OHEM

损失函数当中加入OHEM思想——图像分割损失函数OhemCELoss

文献引用:


前言

在接触一个新技术之前,肯定是因为遇到了新的难题,但这可以促使我们前进。

之前做个一个目标检测任务,每个类别之间的数据量差距较大,有明显的类别不均衡现象(当样本比例大于4:1时)。

解决类别不平衡问题现以有了较多的可行解决方案:

  • 对于目标检测,可以使用OHEM技术进行扩大数据广度。
  • 使用基于交叉熵损失改进的Focal Loss损失函数
  • 对类别多的数据类别进行欠采样(减少一些重复的数据)。
  • 对于类别较少的数据使用过采样,通过数据增广(色彩变换,仿射变换等)进行扩充。

本篇博客主要会记录在研究OHEM技术时的一些总结。


hard Negative Mining介绍

在two-stage检测算法中,RPN阶段会生成大量的检测框,由于很多时候一张图片可能只会有少量几个标注框(真实框),也就是说绝大部分检测框与真实框没有很大的交集,一般计算的IOU大于设置阈值时认为是正样本,小于设置阈值时是负样本。

但是这样选出来的框不一定是最容易错的框。

我们通常在生成检测框负样本中选出容易预测错误的(当成正样本的),作为新的数据集进行参与训练。

即hard Negative Mining(困难样本挖掘)。

思想:

你不会把所有错题都放到错题集中,只对当中最容易的错放入。

实现思想:

迭代地交替训练,用样本集更新模型,然后再固定模型,来选择分辨错的目标框并加入到样本集中继续训练。

缺点:

hard Negative Mining(困难样本挖掘)需要在不断的训练当中冻结参数、预测选出hard Negative再放入训的训练集,这大幅度的增加了工作量,加大了模型训练的时间。

注:一般使用 SVM 分类器才能使用此方法(SVM 分类器和 Hard Negative Mining Method 交替训练)


OHEM 概要解读

前言:

hard Negative Mining(困难样本挖掘)思想值得我们去使用和学习,但是我们试图在不影响效果的前提下去提高模型的迭代训练速度。故我们提出了OHEM(在线难例挖掘)。

论文:

1604.03540.pdf (arxiv.org)icon-default.png?t=M85Bhttps://arxiv.org/pdf/1604.03540.pdfOHEM(在线难例挖掘)流程概述:

1、进行一次的前向传播,获得每个Region proposal单独的损失值。
2、对每个Region proposal进行NMS计算。
3、对剩下的Region proposal按照损失值进行排序,然后选取损失最大的前一部分Region当做输入再次输入分类回归网络,对于训练多次loss还较高的我们可以认为其是困难样本。
4、将困难样本输入图中的(b)模块,(b)模块是(a)模块的复制版,(b)模块是用来反向传播的部分,然后吧更新的参数共享到(a)部分。

注:所谓的线上挖掘,就是先计算loss→筛选→得到困难负样本。


mmdetection中的OHEM

前言:

其实在mmdetection当中,已经封装好了OHEM的代码,但是大家可能都不知道他在哪,这里我给大家找一下他的位置。

彩蛋:

如何在mmdetection查找自己想要的东西(类或者类的调用等)


损失函数当中加入OHEM思想——图像分割损失函数OhemCELoss

前言:

虽然他在目标检测当中被提出,但是不仅仅是目标检测问题,其他问题都会出现类别不平衡的问题,我们试图把他应用到其他方向当中(例如语义分割)

代码实现:

class OhemCELoss(nn.Module):
    """
    Online hard example mining cross-entropy loss:在线难样本挖掘
    if loss[self.n_min] > self.thresh: 最少考虑 n_min 个损失最大的 pixel,
    如果前 n_min 个损失中最小的那个的损失仍然大于设定的阈值,
    那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]。
    否则,计算前 n_min 个损失:loss = loss[:self.n_min]
    """
    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()     # 将输入的概率 转换为loss值
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')   #交叉熵
 
    def forward(self, logits, labels):
        N, C, H, W = logits.size()
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)     # 排序
        if loss[self.n_min] > self.thresh:       # 当loss大于阈值(由输入概率转换成loss阈值)的像素数量比n_min多时,取所以大于阈值的loss值
            loss = loss[loss>self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)

文献引用:

(32条消息) 【每日一网】Day18:OHEM简单理解_陈子文好帅的博客-CSDN博客_ohem

(32条消息) OHEM 详解_*青云*的博客-CSDN博客_ohem

(32条消息) 图像分割损失函数OhemCELoss_超级无敌陈大佬的跟班的博客-CSDN博客_分割损失函数

猜你喜欢

转载自blog.csdn.net/m0_61139217/article/details/127084869