CrossEntropyLoss (PyTorch cross-entropy loss function)

Introduction

The cross-entropy loss function is mainly used for multi-class classification tasks. It computes the cross-entropy between the model's output and the true labels, and this value serves as the objective function for optimizing the model.

In a multi-class classification task, each sample belongs to one of several possible categories, and the model outputs a probability distribution over those categories for each sample. The cross-entropy loss measures the distance between this predicted distribution and the true label, and thereby guides the optimization of the model.

Usage in PyTorch

class torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

Parameters

  • weight: a one-dimensional tensor of size M, where M is the number of classes, specifying the weight assigned to each class
  • ignore_index: an int specifying a class index to ignore; the default is -100, which means no class is ignored
  • reduction: specifies how the per-sample losses are combined. The options are 'none' (return the loss of each sample individually), 'mean' (return the average loss over the samples), and 'sum' (return the sum of the losses over the samples). A small sketch of how these parameters behave follows this list.
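
Below is a minimal sketch (the numbers are illustrative, not from the original post) of how these three parameters change the result:

import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 0.5, 0.3],
                       [0.1, 1.5, 0.2]])   # raw scores for 2 samples, 3 classes
target = torch.tensor([0, 2])              # class indices

# reduction='none' keeps one loss value per sample
print(nn.CrossEntropyLoss(reduction='none')(logits, target))   # tensor of shape [2]

# weight gives class 2 three times the influence of the other classes
print(nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0, 3.0]))(logits, target))

# ignore_index=2 skips every sample whose label is 2
print(nn.CrossEntropyLoss(ignore_index=2)(logits, target))     # only the first sample contributes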

Usage example

import torch
import torch.nn as nn

batch_size = 32
class_num = 3
inputs = torch.rand(batch_size, class_num)                 # [32, 3] raw, unnormalized scores (logits)
target = torch.randint(0, class_num, size=(batch_size,))   # [32] class indices

loss_func = nn.CrossEntropyLoss()
loss = loss_func(inputs, target)   # pass the logits directly; the loss applies log-softmax internally
print(loss)
# Note: define the loss function before calling it, and remember that size expects a tuple,
# hence the extra parentheses in size=(batch_size,)
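
Since nn.CrossEntropyLoss internally combines LogSoftmax and NLLLoss, the same value can be reproduced explicitly. A small check, reusing inputs, target, and loss from the example above:

log_probs = nn.LogSoftmax(dim=1)(inputs)      # convert logits to log-probabilities
nll_loss = nn.NLLLoss()(log_probs, target)    # negative log-likelihood on log-probabilities
print(torch.allclose(loss, nll_loss))         # True: CrossEntropyLoss == LogSoftmax + NLLLoss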

Model input

  • inputs: the output of the model, with shape (batch_size, class_num), where class_num is the number of classes. These are the raw, unnormalized scores (logits) for each class; nn.CrossEntropyLoss applies log-softmax internally, so no manual softmax conversion is required.
  • target: the true labels, with shape (batch_size), where each element is the index of the class that the sample belongs to.

Calculation method

Binary cross-entropy loss function

$$
L=\frac{1}{N} \sum_i L_i=-\frac{1}{N} \sum_i\left[y_i \cdot \log \left(p_i\right)+\left(1-y_i\right) \cdot \log \left(1-p_i\right)\right]
$$

Parameters

  • N: the number of samples
  • $L_i$: the loss value for the i-th sample
  • $y_i$: the label of the i-th sample, equal to 1 if the sample belongs to the positive class and 0 otherwise
  • $p_i$: the probability predicted by the model that the i-th sample belongs to the positive class, a value between 0 and 1
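
As a sketch with made-up numbers, the binary formula can be evaluated by hand and compared against PyTorch's F.binary_cross_entropy:

import torch
import torch.nn.functional as F

p = torch.tensor([0.9, 0.2, 0.6])   # predicted probabilities p_i for the positive class
y = torch.tensor([1.0, 0.0, 1.0])   # true labels y_i

# L = -1/N * sum_i [ y_i*log(p_i) + (1 - y_i)*log(1 - p_i) ]
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()
builtin = F.binary_cross_entropy(p, y)   # default reduction='mean'
print(manual, builtin)                   # both print the same value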

Multi-class cross-entropy loss function

$$
L=\frac{1}{N} \sum_i L_i=-\frac{1}{N} \sum_i \sum_{c=1}^M y_{ic} \log \left(p_{ic}\right)
$$

Parameters

  • N: the number of samples
  • M: the number of categories
  • $y_{ic}$: the label of the i-th sample for the c-th category, equal to 1 if the sample belongs to category c and 0 otherwise
  • $p_{ic}$: the probability predicted by the model that the i-th sample belongs to the c-th category
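
Similarly, a sketch with made-up logits that evaluates the multi-class formula directly using one-hot labels and checks it against nn.CrossEntropyLoss:

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[1.2, 0.3, -0.5],
                       [0.2, 2.0,  0.1]])    # N=2 samples, M=3 classes
target = torch.tensor([0, 1])                # class indices

p = F.softmax(logits, dim=1)                     # p_ic: predicted probabilities
y = F.one_hot(target, num_classes=3).float()     # y_ic: one-hot labels

manual = -(y * torch.log(p)).sum(dim=1).mean()   # -1/N * sum_i sum_c y_ic * log(p_ic)
builtin = nn.CrossEntropyLoss()(logits, target)
print(manual, builtin)                           # both print the same value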

Advantages

When a model is optimized with backpropagation and gradient descent, the size of each parameter update depends on the learning rate and on the partial derivative (gradient) of the loss. The learning rate is set manually, so consider the gradient: for the cross-entropy loss, the gradient grows as the gap between the predicted probabilities and the true labels grows. In other words, the worse the current prediction, the larger the gradient and the faster the parameters move, so training with cross-entropy learns quickly and converges more easily while the model is still performing poorly.
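
Concretely, for a single sample the gradient of the cross-entropy loss with respect to the logits is softmax(logits) - one_hot(target), so the further the predicted distribution is from the true label, the larger the gradient. A small autograd check with illustrative numbers:

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[0.5, 2.0, -1.0]], requires_grad=True)  # one sample, 3 classes
target = torch.tensor([0])                # true class is 0, but class 1 scores highest

loss = nn.CrossEntropyLoss()(logits, target)
loss.backward()

expected = F.softmax(logits, dim=1) - F.one_hot(target, num_classes=3)
print(logits.grad)            # large gradient because the prediction is poor
print(expected.detach())      # identical to logits.grad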

Disadvantages

The cross-entropy loss is focused purely on classification. It is good at learning what separates the different categories, but it only cares about the predicted probability of the correct class, so it tends to ignore the differences and relationships among the other labels. As a result, the learned feature representations are comparatively loose.

