[Deep Learning] Analysis of Classification Loss Functions

1. Introduction

In classification tasks, we use a loss function to measure the difference between the model's output and the true labels, but it is not always obvious which one to choose. This post analyzes several common classification loss functions in PyTorch and gives a code example for each.

2. Analysis

  • Binary Cross Entropy Loss (BCELoss):
    torch.nn.BCELoss() is the loss function for binary classification. It compares the probabilities output by the model with the binary true labels and computes the binary cross-entropy loss (see the formulas sketched after this list). BCELoss can also handle the case where each sample belongs to multiple categories. When using BCELoss, note that the model output must first be converted into probabilities in [0, 1], typically via the sigmoid activation function.

  • Binary Cross Entropy With Logits Loss (BCEWithLogitsLoss):
    torch.nn.BCEWithLogitsLoss() combines a sigmoid layer and the binary cross-entropy loss in a single, numerically more stable operation. When using BCEWithLogitsLoss, there is no need to apply the sigmoid function to the model output manually, because the function already performs this step internally.

  • Multiclass Cross Entropy Loss (CrossEntropyLoss):
    torch.nn.CrossEntropyLoss() is the loss function for multi-class classification tasks. It compares the score the model outputs for each category with the true label and computes the cross-entropy loss. CrossEntropyLoss is suitable for cases where each sample belongs to exactly one category. Note that CrossEntropyLoss expects raw, unnormalized scores (logits): it applies log-softmax internally, so you should not pass the model output through softmax or log-softmax yourself.

  • Multilabel Binary Cross Entropy Loss:
    When each sample can belong to multiple categories, we can use binary cross-entropy loss for multi-label classification. For each sample, the probabilities output by the model are compared with the true labels, and the binary cross-entropy loss is computed label by label. You can either apply BCELoss to each label separately, or use torch.nn.BCEWithLogitsLoss() with the last dimension of the model output set to the number of labels.

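For reference, here are the two underlying formulas (standard textbook definitions, not anything PyTorch-specific). For a predicted probability p in (0, 1) and a binary label y in {0, 1}, the binary cross-entropy is

$$\mathrm{BCE}(p, y) = -\big[\, y \log p + (1 - y) \log (1 - p) \,\big],$$

and for a logit vector z with true class index c, the multi-class cross-entropy is

$$\mathrm{CE}(z, c) = -\log \frac{e^{z_c}}{\sum_j e^{z_j}}.$$

BCEWithLogitsLoss evaluates BCE with p = sigmoid(z) computed internally, and CrossEntropyLoss evaluates CE directly from the logits.
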
3. Code examples

1) Binary cross-entropy loss (BCELoss):

import torch
import torch.nn as nn

# Raw scores standing in for model(input); pass them through sigmoid
# so that every value is a probability in [0, 1], as BCELoss requires
model_output = torch.sigmoid(torch.randn(4))
# Ground-truth binary labels
target = torch.tensor([0., 1., 1., 0.])
# Create the loss function object
loss_fn = nn.BCELoss()
# Compute the loss
loss = loss_fn(model_output, target)
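
To tie this back to the formula in Section 2, here is a quick sanity check (with made-up probabilities and labels) that nn.BCELoss matches the hand-computed binary cross-entropy:

import torch
import torch.nn as nn

# Hypothetical predicted probabilities and binary labels
p = torch.tensor([0.3, 0.8])
y = torch.tensor([0., 1.])
# Hand-computed BCE, averaged over the batch (BCELoss's default reduction)
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()
assert torch.allclose(nn.BCELoss()(p, y), manual)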

2) Binary cross-entropy loss with logits (BCEWithLogitsLoss):

import torch
import torch.nn as nn

# Raw model output (logits) standing in for model(input); no sigmoid here
model_output = torch.randn(4)
# Ground-truth binary labels
target = torch.tensor([0., 1., 1., 0.])
# Create the loss function object
loss_fn = nn.BCEWithLogitsLoss()
# Compute the loss (sigmoid is applied internally)
loss = loss_fn(model_output, target)
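
As a sanity check (a minimal sketch with random logits), BCEWithLogitsLoss agrees, up to floating-point error, with applying sigmoid manually and then BCELoss:

import torch
import torch.nn as nn

logits = torch.randn(4)
target = torch.tensor([0., 1., 1., 0.])
# One fused call vs. the explicit two-step version
combined = nn.BCEWithLogitsLoss()(logits, target)
two_step = nn.BCELoss()(torch.sigmoid(logits), target)
assert torch.allclose(combined, two_step)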

3) Multi-class cross-entropy loss (CrossEntropyLoss):

import torch
import torch.nn as nn

# Raw model output (logits) standing in for model(input):
# shape (batch_size, num_classes). Do NOT apply softmax here;
# CrossEntropyLoss applies log-softmax internally.
model_output = torch.randn(3, 4)
# Ground-truth class indices (each sample belongs to exactly one class)
target = torch.LongTensor([2, 1, 0])
# Create the loss function object
loss_fn = nn.CrossEntropyLoss()
# Compute the loss
loss = loss_fn(model_output, target)
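
To see why an extra softmax is unnecessary, here is a minimal sketch (random logits) showing that CrossEntropyLoss already equals log-softmax followed by NLLLoss:

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(3, 4)
target = torch.LongTensor([2, 1, 0])
# Fused call vs. explicit log-softmax + negative log-likelihood
direct = nn.CrossEntropyLoss()(logits, target)
two_step = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)
assert torch.allclose(direct, two_step)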

4) Multi-label binary cross-entropy loss:

import torch
import torch.nn as nn

# Raw model output (logits) standing in for model(input):
# shape (batch_size, num_labels), so the last dimension is the label count
model_output = torch.randn(4, 2)
# Ground-truth multi-hot labels: each sample may have several active labels
target = torch.tensor([[0., 1.], [1., 1.], [1., 0.], [0., 1.]])
# Create the loss function object (sigmoid is applied internally, per label)
loss_fn = nn.BCEWithLogitsLoss()
# Compute the loss, averaged over all labels and samples
loss = loss_fn(model_output, target)
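
The label-by-label alternative mentioned in Section 2 gives the same result: computing the loss one label column at a time and averaging matches the single call above (a sketch with random logits and two labels):

import torch
import torch.nn as nn

logits = torch.randn(4, 2)
target = torch.tensor([[0., 1.], [1., 1.], [1., 0.], [0., 1.]])
loss_fn = nn.BCEWithLogitsLoss()
# All labels at once vs. one column per label, then averaged
all_at_once = loss_fn(logits, target)
per_label = torch.stack([loss_fn(logits[:, j], target[:, j])
                         for j in range(2)]).mean()
assert torch.allclose(all_at_once, per_label)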

Origin blog.csdn.net/qq_51392112/article/details/132801845