[Object Detection: OHEM Explained] Dealing with the Class Imbalance Problem

Table of contents

Foreword

Hard Negative Mining Introduction

OHEM Overview and Interpretation

OHEM in mmdetection

Adding the OHEM Idea to the Loss Function: the Image Segmentation Loss OhemCELoss

Literature citations


Foreword

We usually reach for a new technique only because we have run into a new difficulty, but that is exactly what pushes us forward.

When working on an object detection task, we found a large gap in the amount of data between categories, i.e. an obvious class imbalance (commonly taken to mean a sample ratio greater than 4:1).

There are now several feasible ways to tackle the class imbalance problem:

  • For object detection, use OHEM to mine hard examples online during training.
  • Use Focal Loss, an improvement on the cross-entropy loss (a minimal sketch follows this list).
  • Undersample the categories with many samples (drop some of the repeated data).
  • Oversample the categories with few samples, expanding them with data augmentation (color transforms, affine transforms, etc.).
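As a quick illustration of the Focal Loss option above, here is a minimal sketch of a binary focal loss built on top of the cross-entropy loss (a sketch only: the function name is mine, and alpha/gamma are the defaults commonly cited for the method):

import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss: down-weights easy examples so training
    focuses on the hard, misclassified ones."""
    # plain per-example cross-entropy
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)             # prob. of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t)^gamma shrinks the loss of well-classified examples toward zero
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()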

This post mainly records some of my notes from studying OHEM.


Hard Negative Mining Introduction

In two-stage detection algorithms, the RPN stage generates a large number of candidate boxes, while in many cases an image contains only a few labeled (ground-truth) boxes. This means most candidate boxes overlap very little with any ground truth. Conventionally, a candidate whose computed IoU with a ground-truth box is greater than a set threshold is treated as a positive sample, and one below the threshold as a negative sample.
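For reference, a minimal sketch of the IoU computation and the thresholding rule just described (the 0.5 threshold and the box coordinates are only illustrative):

def box_iou(a, b):
    """IoU between two axis-aligned boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

# a proposal is positive if its IoU with a ground-truth box exceeds the threshold
label = 'positive' if box_iou((0, 0, 10, 10), (2, 2, 12, 12)) > 0.5 else 'negative'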

But the boxes selected this way are not necessarily the ones the model most easily gets wrong.

Instead, from the negative candidate boxes we usually pick out the ones that are easy to predict wrong (negatives the model misclassifies as positives) and use them as a new data set for further training.

That is Hard Negative Mining.

The idea:

You do not copy every wrong question into your error notebook, only the ones you get wrong most easily.

Implementation:

Train iteratively and alternately: update the model on the current sample set, then freeze the model and use it to pick out the wrongly classified candidate boxes, add those to the sample set, and continue training (a toy sketch of this loop follows).
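A toy sketch of that alternating loop, using a made-up linear classifier and random data purely to show the structure (nothing here comes from a specific detection library):

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 2)                       # stand-in for the real classifier
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

positives = torch.randn(64, 16) + 1.0          # class 1 samples
train_neg = torch.randn(64, 16) - 1.0          # initial random negatives (class 0)
neg_pool = torch.randn(1024, 16) - 1.0         # large pool of unused negatives

for round_idx in range(3):
    # 1. update the model on the current sample set
    x = torch.cat([positives, train_neg])
    y = torch.cat([torch.ones(len(positives)), torch.zeros(len(train_neg))]).long()
    for _ in range(50):
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()
    # 2. freeze the model and score the remaining negative pool
    with torch.no_grad():
        pred = model(neg_pool).argmax(dim=1)
    # 3. negatives misclassified as positive are the hard negatives
    hard = neg_pool[pred == 1]
    neg_pool = neg_pool[pred == 0]             # drop mined ones from the pool
    # 4. grow the training set and repeat
    train_neg = torch.cat([train_neg, hard])
    print(f"round {round_idx}: mined {len(hard)} hard negatives")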

Shortcoming:

Hard Negative Mining needs to repeatedly freeze the parameters during training, run prediction to select the hard negatives, and then put them back into the training set. This greatly increases the workload and lengthens model training.

Note: this method can only be used with an SVM classifier (the SVM classifier and the Hard Negative Mining step are trained alternately).


OHEM Overview and Interpretation

Foreword:

The idea behind Hard Negative Mining is well worth borrowing, but we would like to speed up the model's iterative training without hurting its accuracy. That is the motivation behind OHEM (Online Hard Example Mining).

Paper:

Training Region-based Object Detectors with Online Hard Example Mining: https://arxiv.org/pdf/1604.03540.pdf

OHEM (Online Hard Example Mining) process overview:

1. Run a forward pass to obtain an individual loss value for each region proposal.
2. Apply NMS to the region proposals so that heavily overlapping proposals are not selected repeatedly.
3. Sort the remaining region proposals by loss value, then take the portion with the largest loss as input to the classification/regression network. Proposals whose loss stays high after several rounds of training can be regarded as hard examples.
4. Feed the hard examples into module (b) in the paper's figure. Module (b) is a copy of module (a): (b) is the part that performs backpropagation, and the updated parameters are then shared back to (a).

Note: the "online" part means the mining happens inside the training loop itself: compute the loss → filter → obtain the hard examples. A small sketch of the step-3 selection follows.
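A minimal sketch of the step-3 selection, assuming we already have one loss value per surviving proposal (the function name and the numbers are illustrative):

import torch

def select_hard_examples(losses, k):
    """Return the indices of the k region proposals with the largest loss."""
    k = min(k, losses.numel())
    _, idx = torch.sort(losses, descending=True)
    return idx[:k]

# e.g. per-proposal losses from the read-only forward pass, after NMS
roi_losses = torch.rand(2000)
hard_idx = select_hard_examples(roi_losses, k=128)
# only these proposals are fed to the (b) copy of the network for backprop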


OHEM in mmdetection

Foreword:

In fact, mmdetection already encapsulates OHEM; you may just not know where it lives, so let me point out its location here.
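In mmdetection 2.x the sampler is registered under the name OHEMSampler and implemented in mmdet/core/bbox/samplers/ohem_sampler.py. Below is a sketch of how it is typically switched on in a two-stage detector's train_cfg (field values are common defaults; exact paths and fields can differ between versions, so verify against your installation):

# inside the model config of a two-stage detector such as Faster R-CNN
train_cfg = dict(
    rcnn=dict(
        sampler=dict(
            type='OHEMSampler',        # replaces the default RandomSampler
            num=512,                   # RoIs sampled per image
            pos_fraction=0.25,         # fraction of positives among them
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))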

Easter egg:

How to find what you want in mmdetection (a class, where a class is called, etc.).


Adding the OHEM Idea to the Loss Function: the Image Segmentation Loss OhemCELoss

Foreword:

Although OHEM was proposed for object detection, class imbalance is not unique to detection; it shows up in other problems as well. We can try to apply the idea in other directions, such as semantic segmentation.

Code:

import torch
import torch.nn as nn


class OhemCELoss(nn.Module):
    """
    Online hard example mining cross-entropy loss.
    At least the n_min largest pixel losses are considered:
    if even the n_min-th largest loss is above the threshold,
    all losses above the threshold are used: loss = loss[loss > thresh];
    otherwise the top n_min losses are used: loss = loss[:n_min].
    """
    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        # convert the input probability threshold into a loss threshold
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float))
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        # per-pixel cross-entropy: reduction='none' keeps one loss per pixel
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        # flatten the per-pixel losses and sort them in descending order
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)
        thresh = self.thresh.to(loss.device)
        if loss[self.n_min] > thresh:
            # more than n_min pixels exceed the threshold: keep all of them
            loss = loss[loss > thresh]
        else:
            # otherwise fall back to the n_min hardest pixels
            loss = loss[:self.n_min]
        return torch.mean(loss)
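A minimal usage sketch (shapes and hyper-parameters here are illustrative; n_min is often set to a fraction of the pixels in a batch):

# hypothetical example: 4 images, 19 classes, 64x64 crops
logits = torch.randn(4, 19, 64, 64)            # raw network outputs
labels = torch.randint(0, 19, (4, 64, 64))     # ground-truth class ids
n_min = labels.numel() // 16                   # keep at least 1/16 of the pixels
criterion = OhemCELoss(thresh=0.7, n_min=n_min)
loss = criterion(logits, labels)
print(loss.item())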

Literature citations:

[Daily Net] Day18: A simple understanding of OHEM, Chen Ziwen's blog (CSDN)

Detailed Explanation of OHEM, *Qingyun*'s blog (CSDN)

Image Segmentation Loss Function OhemCELoss, Super Invincible Mr. Chen's Follower's blog (CSDN)
