An in-depth look at binary and multi-class cross-entropy loss and Focal Loss
Binary cross-entropy
In binary classification, the model only has two possible outcomes to predict. For each sample, the predicted probabilities of the two classes are $p$ and $1-p$, and the loss is (with $\log$ taken to base $e$):
$$L = \frac{1}{N} \sum_{i} L_i = \frac{1}{N} \sum_{i} -\left[ y_i \cdot \log(p_i) + (1-y_i) \cdot \log(1-p_i) \right]$$
where:
- $y_i$: the label of sample $i$, 1 for the positive class and 0 for the negative class
- $p_i$: the predicted probability that sample $i$ belongs to the positive class
Since the binary cross-entropy is easy to understand, no examples are given here.
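For readers who still want to check the formula numerically, the following is a minimal sketch that evaluates the expression above by hand and compares it with PyTorch's `torch.nn.functional.binary_cross_entropy`. The probabilities and labels are made-up values chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical predicted probabilities of the positive class, and their labels
p = torch.tensor([0.9, 0.2, 0.6])
y = torch.tensor([1.0, 0.0, 1.0])

# Binary cross-entropy written out term by term, averaged over the N samples
manual = (-(y * torch.log(p) + (1 - y) * torch.log(1 - p))).mean()

# PyTorch's built-in version (its default reduction is 'mean')
builtin = F.binary_cross_entropy(p, y)

print(manual.item(), builtin.item())
```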
Multi-class cross-entropy
Multi-class cross-entropy is an extension of binary cross-entropy. The formula differs slightly from the binary case, but it is still fairly easy to understand:
$$L = \frac{1}{N} \sum_{i} L_i = -\frac{1}{N} \sum_{i} \sum_{c=1}^{M} y_{ic} \log(p_{ic})$$
where:
- $M$: the number of classes
- $y_{ic}$: an indicator (0 or 1) that takes 1 if the true class of sample $i$ is $c$, and 0 otherwise
- $p_{ic}$: the predicted probability that sample $i$ belongs to class $c$
For example:
| Prediction (already normalized by softmax) | Ground truth |
|---|---|
| 0.1 0.2 0.7 | 0 0 1 |
| 0.3 0.4 0.3 | 0 1 0 |
| 0.1 0.2 0.7 | 1 0 0 |
Now we use this expression to calculate the value of the loss function in the example above:
$$
\begin{aligned}
\text{sample 1 loss} &= -(0 \times \log 0.1 + 0 \times \log 0.2 + 1 \times \log 0.7) = 0.36 \\
\text{sample 2 loss} &= -(0 \times \log 0.3 + 1 \times \log 0.4 + 0 \times \log 0.3) = 0.92 \\
\text{sample 3 loss} &= -(1 \times \log 0.1 + 0 \times \log 0.2 + 0 \times \log 0.7) = 2.30 \\
L &= \frac{0.36 + 0.92 + 2.30}{3} = 1.19
\end{aligned}
$$
In fact, multi-class cross-entropy only uses the predicted probability of the correct label. For every wrong label $y_{ic}=0$, so its term contributes 0 to the loss.
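The same computation can be sketched in PyTorch. The snippet below takes the three rows of the table as already-normalized probabilities, keeps only the true-class term for each sample, and averages the negative logs. The variable names are illustrative.

```python
import torch

# Rows of the table: predicted probabilities (already softmax-normalized)
probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.3, 0.4, 0.3],
                      [0.1, 0.2, 0.7]])
# One-hot ground truth from the table
one_hot = torch.tensor([[0., 0., 1.],
                        [0., 1., 0.],
                        [1., 0., 0.]])

# Per-sample loss: -sum_c y_ic * log(p_ic); only the true-class term is non-zero
per_sample = -(one_hot * torch.log(probs)).sum(dim=1)
loss = per_sample.mean()

print(per_sample)
print(loss)
```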
PyTorch's CrossEntropyLoss
Parameters
CrossEntropyLoss is fully documented on the official PyTorch website, so here we only briefly introduce its parameters and the format of the values passed in, especially for the multi-class case.
The commonly used parameters are as follows:
- `weight`: a 1-D tensor with one entry per class; the value at each index is the weight of that class. In a GPU environment the tensor must be on the same GPU as the input.
- `reduction`: a string, one of `none`, `mean`, or `sum`; the default is `mean`. As the names suggest, `mean` averages the loss values and `sum` adds them up. `none` computes the loss at each position and returns a tensor with the same shape as the labels.
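A minimal sketch of how these two parameters behave, assuming random logits and hypothetical class weights chosen only for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)                      # N=4 samples, C=3 classes
target = torch.tensor([0, 2, 1, 2])             # class indices
class_weight = torch.tensor([1.0, 2.0, 0.5])    # hypothetical per-class weights

# reduction='none' returns one loss value per target position
per_item = nn.CrossEntropyLoss(reduction='none')(logits, target)
print(per_item.shape)

# 'sum' and 'mean' reduce those values to a scalar
total = nn.CrossEntropyLoss(reduction='sum')(logits, target)
averaged = nn.CrossEntropyLoss(reduction='mean')(logits, target)

# With weight, each sample's loss is scaled by the weight of its true class
weighted = nn.CrossEntropyLoss(weight=class_weight, reduction='none')(logits, target)
print(total, averaged, weighted)
```

Note that when `weight` is combined with `reduction='mean'`, PyTorch divides by the sum of the weights of the target classes rather than by the number of samples.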
More parameters are explained as shown in the figure below:
Usage
CrossEntropyLoss takes two arguments, `input` and `target`, and returns a single output, `Output`.
- `input`: shape $(N, C)$ or $(N, C, d_1, d_2, \ldots)$; the former is the two-dimensional case and the latter the high-dimensional case. Note that the class dimension $C$ is at `dim=1`, whereas many people assume that in the high-dimensional case it should be the last dimension, `dim=-1`.
- `target`: shape $(N)$ or $(N, d_1, d_2, \ldots)$; the former is the two-dimensional case and the latter the high-dimensional case. Note that the values of `target` are class indices, not one-hot vectors. (A tensor of class probabilities with the same shape as `input` is also accepted, as in the second official example below.)
- `Output`: with `reduction='none'` its shape matches the shape of `target`; with `mean` or `sum` it is a scalar.
More parameters are explained as shown in the figure below:
Computing a 5-class cross-entropy loss in the two-dimensional case (example from the official website):
>>> import torch
>>> import torch.nn as nn
>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()
The corresponding cross-entropy calculation in the high-dimensional case:
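import torch
from torch.nn import CrossEntropyLoss

# The last dimension holds the class scores
input = torch.randn(2, 3, 5, 5, 4)
# Four-class targets
target = torch.empty(2, 3, 5, 5, dtype=torch.long).random_(4)
loss_fn = CrossEntropyLoss(reduction='sum')
# The class dimension must be at dim=1, so move it there
_input = torch.permute(input, dims=(0, -1, 1, 2, 3))
loss = loss_fn(_input, target)
print(loss)
# Alternatively, flatten the input to two dimensions first; the result is the same
_input = input.view(-1, 4)
_target = target.view(-1)
loss = loss_fn(_input, _target)
print(loss)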
How it works internally
In PyTorch, `CrossEntropyLoss()` combines `LogSoftmax()` and `NLLLoss()`; that is, its internal implementation is built on these two functions.
import torch
from torch.nn import CrossEntropyLoss

input = torch.rand(3, 5)
target = torch.empty(3, dtype=torch.long).random_(5)
loss_fn = CrossEntropyLoss(reduction='sum')
loss = loss_fn(input, target)
print(loss)
# The same result from LogSoftmax followed by NLLLoss
_input = torch.nn.LogSoftmax(dim=1)(input)
loss = torch.nn.NLLLoss(reduction='sum')(_input, target)
print(loss)
This matches what the official website says: `CrossEntropyLoss()` applies `softmax()` to the input, takes the `log()` of the result, and finally uses `NLLLoss()` to pick out (and negate) the value at the index of the target class.
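The three steps can also be written out by hand. The sketch below, using random inputs for illustration, applies softmax, takes the log, picks the entry at each target index with `gather`, and compares the result with `CrossEntropyLoss`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])

# Steps 1 and 2: softmax over the class dimension, then log
log_probs = torch.log(torch.softmax(logits, dim=1))
# Step 3: what NLLLoss does, i.e. pick the entry at the target index and negate it
picked = log_probs.gather(dim=1, index=target.unsqueeze(1)).squeeze(1)
manual = -picked.sum()

reference = nn.CrossEntropyLoss(reduction='sum')(logits, target)
print(manual, reference)
```

In practice `log_softmax` is preferable to a separate `softmax` followed by `log`, since it is more numerically stable.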
Focal Loss principle and implementation
Focal Loss comes from the paper Focal Loss for Dense Object Detection and is designed to address class imbalance and hard example mining. Its formula is very simple:
$$FL(p_t) = -\alpha_t (1-p_t)^{\gamma} \log(p_t)$$
$p_t$ is the probability the model predicts for the sample's true class. $-\log(p_t)$ is the same as the cross-entropy loss, so the smaller $p_t$ is for the current sample, the less accurate the prediction and the larger the factor $(1-p_t)^{\gamma}$ becomes. This factor acts as the hard-sample coefficient: the more inaccurate the prediction, the more Focal Loss treats the sample as a hard one and the larger the coefficient, so that hard samples contribute more to the loss and to the gradient.
The leading $\alpha_t$ is the class weight coefficient. With a class-imbalanced dataset you want to give a higher weight to the loss contribution of the minority classes, and $\alpha_t$ plays exactly this role. $\alpha_t$ is therefore taken from a vector whose length equals the number of classes and which stores the weight of each class. A common choice is the reciprocal of the number of samples in each class, which offsets the imbalance in sample counts.
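To make the effect of the modulating factor concrete, here is a small sketch comparing a well-classified sample with a poorly classified one. The probabilities 0.9 and 0.1 and the settings $\gamma = 2$, $\alpha_t = 1$ are assumptions chosen for illustration.

```python
import math

gamma = 2.0
alpha_t = 1.0  # class weight left at 1 so that only the modulating factor varies

def cross_entropy(p_t):
    # Plain cross-entropy term for the true class
    return -math.log(p_t)

def focal(p_t):
    # Focal loss term: cross-entropy scaled by alpha_t * (1 - p_t) ** gamma
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)

for p_t in (0.9, 0.1):  # an easy sample, then a hard sample
    print(p_t, (1 - p_t) ** gamma, cross_entropy(p_t), focal(p_t))
```

The easy sample's loss is scaled by about 0.01 while the hard sample's is scaled by 0.81, so the hard sample dominates the total.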
Here is an implementation of a two-dimensional/high-dimensional Focal Loss:
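import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=torch.tensor([0.2, 0.3, 0.5, 1])):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha  # per-class weights; must be on the same device as the input

    def forward(self, input, target):
        # Log-softmax over the class dimension (dim=1)
        logpt = nn.functional.log_softmax(input, dim=1)
        # exp undoes the log and gives the probabilities
        pt = torch.exp(logpt)
        # alpha of each sample's true class, broadcast over the class dimension
        alpha = self.alpha[target].unsqueeze(dim=1)
        # Focal loss formula without the minus sign, which nll_loss supplies
        logpt = alpha * (1 - pt) ** self.gamma * logpt
        # Pick the entry at each target index and sum
        loss = nn.functional.nll_loss(logpt, target, reduction='sum')
        return loss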
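One quick sanity check, sketched under the assumption that the focal loss is written as a standalone function (the name `focal_loss_sum` is hypothetical, not part of PyTorch): with $\gamma = 0$ and every $\alpha_t = 1$, the focal loss should reduce to the ordinary summed cross-entropy, and with $\gamma > 0$ it should not exceed it.

```python
import torch
import torch.nn.functional as F

def focal_loss_sum(logits, target, alpha, gamma):
    """Hypothetical functional form of the focal loss, summed over samples."""
    log_pt = F.log_softmax(logits, dim=1)       # log-probabilities over classes
    pt = log_pt.exp()                           # probabilities
    a = alpha[target].unsqueeze(1)              # weight of each sample's true class
    weighted = a * (1 - pt) ** gamma * log_pt   # focal term without the minus sign
    # nll_loss picks the entry at the target index and negates it
    return F.nll_loss(weighted, target, reduction='sum')

torch.manual_seed(0)
logits = torch.randn(6, 4)                      # random logits, for illustration
target = torch.tensor([0, 1, 2, 3, 3, 0])

plain = focal_loss_sum(logits, target, alpha=torch.ones(4), gamma=0.0)
reference = F.cross_entropy(logits, target, reduction='sum')
focused = focal_loss_sum(logits, target, alpha=torch.ones(4), gamma=2.0)
print(plain, reference, focused)
```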