OhemCrossEntropyLoss (definition, step process, code implementation, OHEM, CE, OHEMCE, OHEMCELoss, OHEMLoss)

1. Definition of Ohem Cross Entropy Loss

OhemCrossEntropyLoss is a loss function used for object detection tasks in deep learning. It is an improved version of the cross-entropy loss designed for imbalanced data distributions and hard-to-train samples. OHEM stands for "Online Hard Example Mining". In object detection, background-class samples usually far outnumber target-class samples, which leads to an imbalanced data distribution, and some hard samples are very challenging for the network to learn. OhemCrossEntropyLoss is designed to address these problems.

The core idea of this loss function is to select only the hard samples with higher loss values for the gradient update during training. By focusing on the samples that are difficult to classify, it helps the network adapt to them and improves the performance of the model.

Mathematically, OhemCrossEntropyLoss can be expressed by the following formula:

$$
\text{OhemCrossEntropyLoss} = -\frac{1}{N} \sum_{i=1}^{N}
\begin{cases}
\log(p_{\text{target}}) & \text{if } y_{\text{target}} = 1 \text{ (target-class sample)} \\
\log(1 - p_{\text{target}}) & \text{if } y_{\text{target}} = 0 \text{ (background-class sample whose loss is above the threshold)} \\
0 & \text{otherwise}
\end{cases}
$$

Here, $N$ is the number of samples in the batch, $p_{\text{target}}$ is the probability the model predicts for the target class, and $y_{\text{target}}$ is the ground-truth label (1 for the target class, 0 for the background class); the loss is computed differently depending on the label. Background-class samples whose loss value is higher than a predefined threshold are selected for the gradient update, so the network pays more attention to the samples that are hard to classify, which helps improve performance.

Note that OhemCrossEntropyLoss needs to dynamically screen hard samples during training, so its computation is more involved than the traditional cross-entropy loss. However, when dealing with imbalanced data and hard samples, it can improve the robustness and generalization ability of the model.
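
To make the selection concrete, here is a minimal sketch (not the code from section 3) that computes per-sample cross-entropy losses with reduction='none' and keeps only the samples whose loss exceeds a threshold; the batch size, class count, and the threshold value -log(0.7) are illustrative assumptions.

import torch
import torch.nn.functional as F

# Illustrative batch: 8 samples, 3 classes (shapes chosen only for this sketch)
logits = torch.randn(8, 3)
labels = torch.randint(low=0, high=3, size=(8,))

# Per-sample cross-entropy (reduction='none' keeps one loss value per sample)
per_sample_loss = F.cross_entropy(logits, labels, reduction='none')

# Hard samples: loss above -log(0.7), i.e. true-class probability below 0.7
loss_thresh = -torch.log(torch.tensor(0.7))
hard_loss = per_sample_loss[per_sample_loss > loss_thresh]

# OHEM-style loss: average only over the hard samples (fall back to all samples if none qualify)
ohem_loss = hard_loss.mean() if hard_loss.numel() > 0 else per_sample_loss.mean()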

2. OHEM step process

  1. Give OhemCELoss a threshold thresh (0.7 in the example below):

    • If the predicted probability of a pixel is > 0.7, the pixel is regarded as a simple sample and does not participate in the loss calculation.
    • If the predicted probability of a pixel is < 0.7, the pixel is regarded as a hard sample and participates in the loss calculation.
  2. Determine the ignored label value lb_ignore: generally the background is set to 255, i.e., pixels whose label value is 255 do not participate in the loss calculation.

  3. Set the minimum number of pixels for the calculation n_min: at least n_min pixels participate in the loss calculation (otherwise the network may stop updating).

To put it simply: the purpose of OHEM CrossEntropy Loss is to mine hard samples and ignore simple samples. The small check below shows how the probability threshold maps to the loss threshold used in the implementation.
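
The probability threshold in step 1 and the loss threshold used inside the implementation in section 3 are two views of the same cut-off: a pixel whose predicted probability for its true class is above thresh has a cross-entropy loss below -log(thresh). A small numeric check of this conversion (0.7 is the example threshold from the steps above):

import torch

thresh = 0.7                                    # probability threshold on the true-class prediction
loss_thresh = -torch.log(torch.tensor(thresh))  # equivalent threshold on the per-pixel CE loss
print(loss_thresh)                              # tensor(0.3567): loss > 0.3567  <=>  probability < 0.7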

3. Code implementation

import random
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn


def setup_seed(seed):
    """Fix all random seeds so that the example below is reproducible."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    
    
class OhemCELoss(nn.Module):
    def __init__(self, thresh, lb_ignore=255, ignore_simple_sample_factor=16):
        """
            Args:
                thresh: 阈值,超过该值则被算法简单样本 -> 不参与Loss计算
                lb_ignore: 忽略的像素值(一般255代表背景), 不参与损失的计算
                ignore_simple_sample_factor: 忽略简单样本的系数
                                                    该系数越大,最少计算的像素点个数越少
                                                    该系数越小,最少计算的像素点个数越多
        """
        super(OhemCELoss, self).__init__()
        
        """
            这里的 thresh 和 self.thresh 不是一回儿事儿
                ①预测概率 > thresh -> 简单样本
                ①预测概率 < thresh -> 困难样本
                ②损失值 > self.thresh -> 困难样本
                ②损失值 < self.thresh -> 简单
                
                ①和②其实是一回儿事儿,但 thresh 和 self.thresh 不是一回儿事儿
        """
        self.thresh = -torch.log(input=torch.tensor(thresh, requires_grad=False, dtype=torch.float))
        self.lb_ignore = lb_ignore
        self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none')
        self.ignore_simple_sample_factor = ignore_simple_sample_factor
        
        """
            reduction 参数用于控制损失的计算方式和输出形式。它有三种可选的取值:
                1. 'none':当设置为 'none' 时,损失将会逐个样本计算,返回一个与输入张量相同形状的损失张量。
                           这意味着输出的损失张量的形状与输入的标签张量相同,每个位置对应一个样本的损失值。
                2. 'mean':当设置为 'mean' 时,损失会对逐个样本计算的损失进行求均值,得到一个标量值。
                           即计算所有样本的损失值的平均值。
                3. 'sum' : 当设置为 'sum'  时,损失会对逐个样本计算的损失进行求和,得到一个标量值。
                           即计算所有样本的损失值的总和。

            在语义分割任务中,通常使用 ignore_index 参数来忽略某些特定标签,例如背景类别。
            当计算损失时,将会忽略这些特定标签的损失计算,以避免这些标签对损失的影响。
            如果设置了 ignore_index 参数,'none' 的 reduction 参数会很有用,因为它可以让你获取每个样本的损失,包括被忽略的样本。

            总之,reduction 参数允许在计算损失时控制输出形式,以满足不同的需求。
        """
    
    def forward(self, logits, labels):
        # 1. Compute n_min (the minimum number of pixels that must participate in the loss)
        n_min = labels[labels != self.lb_ignore].numel() // self.ignore_simple_sample_factor
        
        # 2. Compute the per-pixel loss with CrossEntropy, then flatten it
        loss = self.criteria(logits, labels).view(-1)
        
        # 3. Select all pixels whose loss is greater than self.thresh -> hard samples
        loss_hard = loss[loss > self.thresh]
        
        # 4. If fewer than n_min pixels qualify, make sure the n_min largest losses still participate
        if loss_hard.numel() < n_min:
            loss_hard, _ = loss.topk(n_min)
            
        # 5. If more than n_min pixels qualify, all of them participate in the calculation
        loss_hard_mean = torch.mean(loss_hard)
        
        # 6. Return the mean of the hard-sample losses
        return loss_hard_mean
    
    
if __name__ == "__main__":
    setup_seed(20)
    
    # 1. Generate predictions (assume 2 samples, each with 3 classes, height and width both 4)
    logits = Variable(torch.randn(2, 3, 4, 4))  # [N, C, H, W], s.t. C <-> num_classes
    
    # 2. Generate ground-truth labels (each sample's label is a 4x4 image)
    labels = Variable(torch.randint(low=0, high=3, size=(2, 4, 4)))  # [N, H, W]
    
    # 3. Initialization: create an OhemCELoss instance with the threshold set to 0.7
    ohem_criterion = OhemCELoss(thresh=0.7, lb_ignore=255, ignore_simple_sample_factor=16)
    
    # 4. Compute the Ohem loss
    loss = ohem_criterion(logits, labels)
    
    print(f"Ohem Loss: {
      
      loss.item()}")  # Ohem Loss: 1.3310734033584595
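
For context, here is a minimal training-step sketch that uses the OhemCELoss class defined above on random data; the 1x1-convolution "model", the SGD settings, and the tensor shapes are assumptions made purely for illustration and are not part of the original example.

# Toy training step (illustrative only; the 1x1 conv stands in for a real segmentation model)
model = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = OhemCELoss(thresh=0.7, lb_ignore=255)

images = torch.randn(2, 3, 4, 4)                         # [N, 3, H, W]
targets = torch.randint(low=0, high=3, size=(2, 4, 4))   # [N, H, W]

optimizer.zero_grad()
loss = criterion(model(images), targets)                 # OHEM loss computed over the hard pixels only
loss.backward()
optimizer.step()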

Knowledge sources

  1. https://www.bilibili.com/video/BV12841117yo
  2. https://www.bilibili.com/video/BV1Um4y1L753
