In-Depth Understanding of Binary and Multi-Class Cross-Entropy Loss and Focal Loss

Binary Cross-Entropy

In the binary case, the model only has to choose between two outcomes. The predicted probabilities of the two classes are $p$ and $1-p$, and the loss is ($\log$ has base $e$):

$$L=\frac{1}{N} \sum_{i} L_i = -\frac{1}{N} \sum_{i} \left[ y_i \cdot \log(p_i) + (1-y_i) \cdot \log(1-p_i) \right]$$
where:

  • $y_i$: the label of sample $i$, 1 for the positive class and 0 for the negative class
  • $p_i$: the probability that sample $i$ is predicted to be the positive class

Since binary cross-entropy is easy to understand, no full worked example is given here; only a quick code check of the formula is sketched below.
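
For completeness, here is a minimal sketch (the tensors and their values are purely illustrative assumptions) that compares the formula above with PyTorch's built-in torch.nn.BCELoss:

import torch
import torch.nn as nn

p = torch.tensor([0.9, 0.2, 0.7])  # predicted probabilities p_i (assumed values)
y = torch.tensor([1.0, 0.0, 1.0])  # labels y_i, positive class = 1, negative class = 0

# The formula above: L = -mean( y*log(p) + (1-y)*log(1-p) )
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()
builtin = nn.BCELoss()(p, y)  # note: BCELoss expects probabilities, not raw logits
print(manual.item(), builtin.item())  # the two numbers should agree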

Multi-Class Cross-Entropy

Multi-class cross-entropy is an extension of binary cross-entropy. The formula differs slightly from the binary case but is still easy to follow:

$$L=\frac{1}{N} \sum_{i} L_i = -\frac{1}{N} \sum_{i} \sum_{c=1}^{M} y_{ic} \log(p_{ic})$$
where:

  • $M$: the number of classes
  • $y_{ic}$: indicator function (0 or 1), 1 if the true class of sample $i$ equals $c$, otherwise 0
  • $p_{ic}$: the predicted probability that sample $i$ belongs to class $c$

for example

Prediction (already normalized by softmax)    Ground truth (one-hot)
0.1  0.2  0.7                                 0  0  1
0.3  0.4  0.3                                 0  1  0
0.1  0.2  0.7                                 1  0  0

Now we use this expression to calculate the value of the loss function in the example above:
$$\text{sample 1 loss} = -(0 \times \log 0.1 + 0 \times \log 0.2 + 1 \times \log 0.7) = 0.36$$
$$\text{sample 2 loss} = -(0 \times \log 0.3 + 1 \times \log 0.4 + 0 \times \log 0.3) = 0.92$$
$$\text{sample 3 loss} = -(1 \times \log 0.1 + 0 \times \log 0.2 + 0 \times \log 0.7) = 2.30$$
$$L = \frac{0.36 + 0.92 + 2.30}{3} \approx 1.19$$
In fact, multi-class cross-entropy only involves the predicted probability of the correct class: for every wrong class $y_{ic}=0$, so those terms contribute nothing to the loss.
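
As a quick sanity check, the same three samples can be reproduced in PyTorch (a minimal sketch; since nn.CrossEntropyLoss expects logits, the already-normalized probabilities are fed in as log-probabilities, which log_softmax leaves unchanged because each row sums to 1):

import torch
import torch.nn as nn

probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.3, 0.4, 0.3],
                      [0.1, 0.2, 0.7]])
target = torch.tensor([2, 1, 0])  # index of the "1" in each one-hot row

# log(probs) acts as logits here: log_softmax(log(probs)) == log(probs) when rows sum to 1
loss = nn.CrossEntropyLoss(reduction='mean')(torch.log(probs), target)
print(loss)  # approx (0.36 + 0.92 + 2.30) / 3 = 1.19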

Analysis of PyTorch's CrossEntropyLoss

parameter setting

CrossEntropyLoss is fully documented on the PyTorch official website, so here we only briefly introduce its parameters and the format of the values passed in, especially for the multi-class case.

The commonly used parameters are as follows:

  • weight: a tensor of length C (one entry per class); the value at position c is the weight applied to class c. Note that in a GPU environment the weight tensor must also be moved to the GPU.

  • reduction: a string with three options, 'mean', 'sum', and 'none'; the default is 'mean'. As the names suggest, 'mean' averages the loss values and 'sum' adds them up, while 'none' returns the loss computed at each position, with the same shape as the label. A short sketch after this list shows how the three modes behave.
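
A small sketch of how weight and reduction behave (the logits, targets, and weights below are arbitrary assumptions):

import torch
import torch.nn as nn

logits = torch.randn(4, 3)  # 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])
w = torch.tensor([1.0, 2.0, 0.5])  # per-class weights; move to the GPU with .to(device) if needed

print(nn.CrossEntropyLoss(weight=w, reduction='mean')(logits, target))  # scalar, weighted average
print(nn.CrossEntropyLoss(weight=w, reduction='sum')(logits, target))   # scalar, weighted sum
print(nn.CrossEntropyLoss(weight=w, reduction='none')(logits, target))  # shape (4,), same as target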

The remaining parameters are explained in the official documentation.

Usage

CrossEntropyLoss takes two values, input and target, and produces a single Output.

  • input has shape $(N, C)$ or $(N, C, d_1, d_2, \ldots)$; the former is the two-dimensional case and the latter the higher-dimensional case. Note that the class dimension $C$ sits at dim=1; in the high-dimensional case many people assume it defaults to the last dimension (dim=-1), but it does not.

  • target has shape $(N)$ or $(N, d_1, d_2, \ldots)$, again for the two-dimensional and higher-dimensional cases. Note that target stores the index of the class, not a one-hot encoding; see the sketch after this list.

  • Output has the same shape as target when reduction='none'; with 'mean' or 'sum' it is a scalar.
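
A minimal sketch of the point about class indices: if the labels are currently one-hot, argmax along the class dimension turns them into the index form that target expects (the shapes here are assumptions):

import torch
import torch.nn as nn

logits = torch.randn(4, 3)  # input of shape (N, C), with C at dim=1
one_hot = torch.tensor([[0, 0, 1],
                        [0, 1, 0],
                        [1, 0, 0],
                        [0, 0, 1]], dtype=torch.float)
target = one_hot.argmax(dim=1)  # tensor([2, 1, 0, 2]) of shape (N,)
print(nn.CrossEntropyLoss()(logits, target))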


Five-class cross-entropy loss in the two-dimensional case (example from the official documentation):

>>> import torch
>>> import torch.nn as nn
>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()

The corresponding cross-entropy calculation in the high-dimensional case:

import torch
from torch.nn import CrossEntropyLoss

input = torch.randn(2, 3, 5, 5, 4)  # here the last dimension corresponds to the class
target = torch.empty(2, 3, 5, 5, dtype=torch.long).random_(4)  # 4 classes
loss_fn = CrossEntropyLoss(reduction='sum')
_input = torch.permute(input, dims=(0, -1, 1, 2, 3))  # the class dimension must be at dim=1
loss = loss_fn(_input, target)
print(loss)
# Alternatively, flatten the input to 2-D first and then compute; the result is the same
_input = input.view(-1, 4)
_target = target.view(-1)
loss = loss_fn(_input, _target)
print(loss)

Internal principle

In PyTorch, CrossEntropyLoss() merges LogSoftmax() and NLLLoss(), which means its internal implementation amounts to LogSoftmax() followed by NLLLoss().

import torch
from torch.nn import CrossEntropyLoss

input = torch.rand(3, 5)
target = torch.empty(3, dtype=torch.long).random_(5)
loss_fn = CrossEntropyLoss(reduction='sum')
loss = loss_fn(input, target)
print(loss)
# Equivalent computation: LogSoftmax followed by NLLLoss
_input = torch.nn.LogSoftmax(dim=1)(input)
loss = torch.nn.NLLLoss(reduction='sum')(_input, target)
print(loss)

This matches what the official documentation says: CrossEntropyLoss() applies softmax() to the output, takes the log() of the result, and finally lets NLLLoss() pick out (and negate) the value at the index of the target class.
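
The same equivalence can also be written out by hand with gather (a minimal sketch for the two-dimensional case):

import torch
import torch.nn as nn

input = torch.randn(3, 5)
target = torch.empty(3, dtype=torch.long).random_(5)

log_probs = torch.log_softmax(input, dim=1)                  # softmax, then log
picked = log_probs.gather(dim=1, index=target.unsqueeze(1))  # take the entry of the target class in each row
manual = -picked.squeeze(1).sum()                            # negate and reduce, as NLLLoss does
print(manual, nn.CrossEntropyLoss(reduction='sum')(input, target))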

Focal Loss principle and implementation

Focal Loss comes from the paper Focal Loss for Dense Object Detection and is used to tackle class imbalance and hard-example mining. Its formula is very simple:

$$FL(p_t) = -\alpha_t (1-p_t)^{\gamma} \log(p_t)$$

$p_t$ is the probability the model predicts for the true class of the sample, and $-\log(p_t)$ is exactly the cross-entropy term. When the $p_t$ of the current sample's class is small, the prediction is inaccurate and the factor $(1-p_t)^{\gamma}$ grows; this factor serves as the coefficient for hard examples. The less accurate the prediction, the more Focal Loss treats the sample as a hard example and the larger the coefficient, so that hard examples contribute more to the loss and the gradient.

The leading $\alpha_t$ is the class-weight coefficient. On a class-imbalanced dataset you want the loss contribution of the minority classes to carry a higher weight, and that is exactly the role $\alpha_t$ plays. $\alpha_t$ is therefore drawn from a vector whose length equals the number of classes and which stores the weight of each class. In general the values are set to the reciprocal of each class's sample count, which balances out the difference in class sizes.
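
A tiny numerical illustration of the modulating factor, with assumed values $\gamma = 2$, $p_t = 0.9$ for an easy sample and $p_t = 0.1$ for a hard one:

import math

gamma = 2
for p_t in (0.9, 0.1):            # assumed easy vs. hard sample
    ce = -math.log(p_t)           # ordinary cross-entropy term
    factor = (1 - p_t) ** gamma   # focal modulating factor
    print(f"p_t={p_t}: CE={ce:.3f}, factor={factor:.2f}, FL={factor * ce:.3f}")

The easy sample's loss is scaled by 0.01 while the hard sample keeps most of its loss, which is exactly the down-weighting of easy examples described above.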

Here is an implementation of a two-dimensional/high-dimensional Focal Loss:

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=torch.tensor([0.2, 0.3, 0.5, 1])):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, input, target):
        logpt = nn.functional.log_softmax(input, dim=1)  # softmax followed by log
        pt = torch.exp(logpt)  # exp undoes the log, recovering the probabilities
        alpha = self.alpha[target].unsqueeze(dim=1)  # look up the alpha of each sample's true class
        logpt = alpha * (1 - pt) ** self.gamma * logpt  # the Focal Loss formula
        loss = nn.functional.nll_loss(logpt, target, reduction='sum')  # finally select the element at the target index
        return loss
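
A possible usage sketch for the class above (it reuses the imports and FocalLoss defined there; the shapes and values are assumptions). With $\gamma = 0$ and a weight of 1 for every class it should fall back to plain CrossEntropyLoss(reduction='sum'):

input = torch.randn(8, 4)  # (N, C) logits for the 4 classes assumed in alpha above
target = torch.empty(8, dtype=torch.long).random_(4)

focal = FocalLoss(gamma=2, alpha=torch.tensor([0.2, 0.3, 0.5, 1.0]))
print(focal(input, target))

# Sanity check: gamma=0 and alpha=1 for every class reduces to ordinary cross-entropy
plain = FocalLoss(gamma=0, alpha=torch.ones(4))
print(plain(input, target), nn.CrossEntropyLoss(reduction='sum')(input, target))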

References

CrossEntropyLoss documentation on the PyTorch official website.

Interpretation of the CrossEntropyLoss() function in PyTorch and computing the loss with one-hot encoding.

Detailed explanation of PyTorch's implementation of multi-class Focal Loss: a concise implementation with alpha.

Original post: blog.csdn.net/qq_45041871/article/details/130565823