Several PyTorch loss functions: CrossEntropyLoss, NLLLoss, BCELoss, BCEWithLogitsLoss, focal_loss, heatmap_loss

These notes record several loss functions commonly used in classification problems. They will be expanded over time.

nn.CrossEntropyLoss() cross entropy loss

Commonly used for multi-class classification problems

CE = nn.CrossEntropyLoss()
loss = CE(input,target)

Input: (N, C), dtype: float, raw unnormalized scores (logits); N is the number of samples, usually batch_size, and C is the number of classes
target: (N), dtype: long, the class index, 0 ≤ target[i] ≤ C−1
The cross entropy loss in PyTorch combines softmax (followed by a logarithm) with NLL loss, namely

nn.CrossEntropyLoss()(input, target) == nn.NLLLoss()(torch.log(nn.Softmax(dim=1)(input)), target)
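The equivalence can be checked numerically. The following is a minimal sketch with random inputs; the shapes, the seed, and the variable names are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, C = 4, 3                          # batch size and number of classes (illustrative values)
logits = torch.randn(N, C)           # raw, unnormalized scores
target = torch.randint(0, C, (N,))   # class indices, dtype long

ce = nn.CrossEntropyLoss()(logits, target)

# Same computation in two steps: log-softmax over the class dimension, then NLL
log_probs = torch.log_softmax(logits, dim=1)
nll = nn.NLLLoss()(log_probs, target)

same = torch.allclose(ce, nll)
print(ce.item(), nll.item(), same)
```

Using log_softmax directly, rather than torch.log applied to a softmax, avoids taking the logarithm of very small probabilities.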

nn.NLLLoss() negative log likelihood loss

NLL = nn.NLLLoss()
loss = NLL(input,target)

Input: (N, C), dtype: float, log-probabilities (for example the output of log-softmax); N is the number of samples, usually batch_size
target: (N), dtype: long, the class index, 0 ≤ target[i] ≤ C−1
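NLLLoss expects log-probabilities as input: it picks the log-probability of the true class for each sample, negates it, and averages over the batch. A small sketch with hand-picked probabilities (the numbers are illustrative only):

```python
import torch
import torch.nn as nn

# Probabilities for N=2 samples and C=3 classes (hand-picked illustrative values)
probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.1, 0.8]])
log_probs = torch.log(probs)         # NLLLoss takes log-probabilities, not probabilities
target = torch.tensor([0, 2])        # true class of each sample

loss = nn.NLLLoss()(log_probs, target)

# Manual version: pick the log-probability of the true class, negate, average
manual = -log_probs[torch.arange(2), target].mean()

print(loss.item(), manual.item())
```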

nn.BCELoss() binary cross entropy loss

Commonly used for binary or multi-label classification

BCE = nn.BCELoss()
loss = BCE(input,target)

Input: (N, x), dtype: float, probabilities in [0, 1] (for example the output of sigmoid); N is the number of samples, usually batch_size, and x is the number of labels
target: (N, x), dtype: float, usually the one-hot (or multi-hot) encoding of the labels; note that it must be converted to float
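As a small illustration of the shapes and dtypes above, the sketch below compares BCELoss with the element-wise formula -[y*log(p) + (1-y)*log(1-p)] averaged over all elements. The values are made up for the example.

```python
import torch
import torch.nn as nn

# N=2 samples, x=3 labels; inputs must already be probabilities in [0, 1]
probs = torch.tensor([[0.9, 0.2, 0.6],
                      [0.3, 0.8, 0.1]])
target = torch.tensor([[1, 0, 1],
                       [0, 1, 0]]).float()   # integer labels converted to float

loss = nn.BCELoss()(probs, target)

# Manual version: -[y*log(p) + (1-y)*log(1-p)], averaged over all N*x elements
manual = -(target * torch.log(probs) + (1 - target) * torch.log(1 - probs)).mean()

print(loss.item(), manual.item())
```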

nn.BCEWithLogitsLoss()

Equivalent to applying sigmoid and then BCELoss, but computed in a single, more numerically stable step

nn.BCEWithLogitsLoss()(input,target) == nn.BCELoss()(torch.sigmoid(input),target)
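A quick numerical check of this relation, as a sketch with random logits (shapes and seed are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

logits = torch.randn(4, 3)                      # raw scores, no sigmoid applied
target = torch.randint(0, 2, (4, 3)).float()    # multi-hot labels as float

with_logits = nn.BCEWithLogitsLoss()(logits, target)
two_step = nn.BCELoss()(torch.sigmoid(logits), target)

same = torch.allclose(with_logits, two_step, atol=1e-6)
print(with_logits.item(), two_step.item(), same)
```

For logits of moderate size the two results agree to floating-point precision; the combined version is the safer choice when logits can become very large in magnitude.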

focal_loss

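Focal loss is not included in torch.nn (torchvision provides torchvision.ops.sigmoid_focal_loss). It is commonly used in object detection problems. See the figure in the paper for the formulas and curves.
The formula of focal loss with the balance parameter $\alpha_t$ is as follows:
$$FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)$$
where $p_t$ is the predicted probability of the true class and $\gamma$ is the focusing parameter.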
Code: (to be added later)
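In the meantime, a minimal sketch of the alpha-balanced binary form could look like the following. The function name focal_loss is a hypothetical helper, not a PyTorch API, and the defaults alpha=0.25 and gamma=2.0 are assumed here as the commonly cited values from the focal loss paper.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, alpha=0.25, gamma=2.0):
    """Alpha-balanced binary focal loss on raw logits (hypothetical helper)."""
    p = torch.sigmoid(logits)
    # Per-element BCE equals -log(p_t)
    ce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = p * target + (1 - p) * (1 - target)              # probability of the true label
    alpha_t = alpha * target + (1 - alpha) * (1 - target)  # class balance weight
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()

logits = torch.tensor([[2.0, -1.0], [0.5, -3.0]])
target = torch.tensor([[1.0, 0.0], [0.0, 0.0]])

loss = focal_loss(logits, target)

# Sanity check: with gamma=0 and alpha=0.5 the loss is half of plain BCE
plain = F.binary_cross_entropy_with_logits(logits, target)
reduced = focal_loss(logits, target, alpha=0.5, gamma=0.0)
print(loss.item(), plain.item(), reduced.item())
```

The factor (1 - p_t) ** gamma shrinks the contribution of well-classified elements, so easy examples dominate the total loss less than they do with plain BCE.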

heatmap_loss

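Heatmap_loss appears in the anchor-free object detection networks CenterNet and CornerNet. It builds on focal loss by reducing the penalty in the hotspot area around each detection point, so that the model output only needs to converge to the neighborhood of the detection point. (Requiring it to converge exactly to the detection point would be too difficult, and convergence would be slow.)
$$L = -\frac{1}{N} \sum_{xyc} \begin{cases} (1-\hat{Y}_{xyc})^\alpha \log(\hat{Y}_{xyc}) & \text{if } Y_{xyc}=1 \\ (1-Y_{xyc})^\beta (\hat{Y}_{xyc})^\alpha \log(1-\hat{Y}_{xyc}) & \text{otherwise} \end{cases}$$
Here $\hat{Y}_{xyc}$ is the predicted heatmap, $Y_{xyc}$ is the ground-truth heatmap, $N$ is the number of detection points, and $\alpha$ is the focusing exponent (written $\gamma$ in the focal loss paper). Note that it only adds an extra factor $(1-Y_{xyc})^\beta$ in the otherwise case; apart from this factor, it is focal loss.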
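A minimal sketch of such a penalty-reduced focal loss on heatmaps is given below. The function name heatmap_loss is a hypothetical helper, and the defaults alpha=2 and beta=4 are assumptions taken from the values commonly reported for CornerNet and CenterNet; this is not the reference implementation of either network.

```python
import torch

def heatmap_loss(pred, gt, alpha=2.0, beta=4.0, eps=1e-6):
    """Penalty-reduced focal loss on heatmaps (hypothetical helper).

    pred: predicted heatmap, probabilities in (0, 1)
    gt:   ground-truth heatmap in [0, 1], exactly 1 at detection points
    """
    pred = pred.clamp(eps, 1 - eps)          # avoid log(0)
    pos = gt.eq(1).float()                   # detection point locations
    neg = 1.0 - pos                          # all other locations

    pos_loss = (1 - pred) ** alpha * torch.log(pred) * pos
    # (1 - gt)^beta shrinks the penalty near detection points, where gt is close to 1
    neg_loss = (1 - gt) ** beta * pred ** alpha * torch.log(1 - pred) * neg

    num_pos = pos.sum().clamp(min=1.0)       # normalize by the number of detection points
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos

# Toy 3x3 heatmap: one detection point in the center with a soft neighborhood
gt = torch.tensor([[0.0, 0.5, 0.0],
                   [0.5, 1.0, 0.5],
                   [0.0, 0.5, 0.0]])
pred = torch.full((3, 3), 0.3)

loss = heatmap_loss(pred, gt)
print(loss.item())
```

With a hard 0/1 ground truth the factor (1 - gt) ** beta is 1 at every negative location, so the expression falls back to the focal loss form without the balance weight.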

Origin blog.csdn.net/Brikie/article/details/116171023