torch BCEWithLogitsLoss

CLASS torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)

1. Parameter description:

2. Shapes of the input logits, the label, and the output loss:

3. Description

This loss combines a Sigmoid layer and BCELoss in one class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss.
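A minimal sketch of why the fused version is more stable (the extreme logit value 1000 is chosen here purely for illustration): for moderate logits the two formulations agree, but for large-magnitude logits `sigmoid` saturates to exactly 1.0 in float32 and the two-step `Sigmoid` + `BCELoss` loses the true loss value, while the fused loss stays correct.

```python
import torch

# Moderate logit: fused and two-step losses agree.
mild_logit = torch.tensor([0.5])
mild_target = torch.tensor([1.0])
fused_mild = torch.nn.BCEWithLogitsLoss()(mild_logit, mild_target)
two_step_mild = torch.nn.BCELoss()(torch.sigmoid(mild_logit), mild_target)

# Extreme logit: sigmoid(1000) rounds to 1.0 in float32, so the
# two-step loss can no longer recover the true value, whereas the
# fused loss computes it in log-space and returns exactly 1000.
big_logit = torch.tensor([1000.0])
zero_target = torch.tensor([0.0])
fused_big = torch.nn.BCEWithLogitsLoss()(big_logit, zero_target)
two_step_big = torch.nn.BCELoss()(torch.sigmoid(big_logit), zero_target)
```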

The unreduced (reduction='none') binary-classification loss can be described as:

L=\{l_{1},\dots,l_{N}\}^{\top},\qquad l_{n}=-w_{n}\left[y_{n}\cdot \log\sigma(x_{n})+(1-y_{n})\cdot \log(1-\sigma(x_{n}))\right]

where N is the batch size, x_n is the logit of the n-th sample, and y_n is its label.

If reduction is 'mean', the sum of the output is divided by the number of elements in the output; if reduction is 'sum', the output is summed. If reduction is 'none', the loss has the same shape as the input and label, i.e. [batch_size, #classes].
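The three reduction modes can be sketched as follows (shapes chosen here as an assumed [batch_size, #classes] = [4, 3] example):

```python
import torch

logits = torch.randn(4, 3)
target = torch.randint(0, 2, (4, 3)).float()

# 'none' keeps the per-element losses; 'mean' and 'sum' reduce to scalars.
loss_none = torch.nn.BCEWithLogitsLoss(reduction='none')(logits, target)
loss_mean = torch.nn.BCEWithLogitsLoss(reduction='mean')(logits, target)
loss_sum = torch.nn.BCEWithLogitsLoss(reduction='sum')(logits, target)

print(loss_none.shape)  # same shape as input/label: torch.Size([4, 3])
print(loss_mean.shape)  # scalar: torch.Size([])
```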

pos_weight specifies the weight of the positive samples of each class; it is a vector whose length equals the number of classes, and it is broadcast against the [batch_size, #classes] input and label. For example, if a dataset contains 100 positive and 300 negative examples of a single class, then pos_weight for that class should be 300/100 = 3. The loss then acts as if the dataset contained 3×100 = 300 positive examples.
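A small sketch of the 100-positive / 300-negative example above (the logit value 0.7 is arbitrary): for a positive target the loss term is -p_c·log σ(x), so with pos_weight = 3 the loss on positive examples is exactly three times the unweighted loss.

```python
import torch

# pos_weight is a vector with one entry per class (here a single class).
num_pos, num_neg = 100.0, 300.0
pos_weight = torch.tensor([num_neg / num_pos])  # tensor([3.])

weighted = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
unweighted = torch.nn.BCEWithLogitsLoss()

logit = torch.tensor([[0.7]])
positive = torch.ones(1, 1)

# For y = 1 the per-sample loss is -p_c * log(sigmoid(x)),
# so the weighted loss is 3x the unweighted one.
w = weighted(logit, positive)
u = unweighted(logit, positive)
```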

In multi-label classification, the loss can be described as:

L_c=\{l_{1,c},\dots,l_{N,c}\}^{\top},\qquad l_{n,c}=-w_{n,c}\left[p_{c}\,y_{n,c}\cdot \log\sigma(x_{n,c})+(1-y_{n,c})\cdot \log(1-\sigma(x_{n,c}))\right]

where c is the class index and p_c is the weight of the positive answer for class c. p_c > 1 increases the recall, p_c < 1 increases the precision.

4. Example:

import torch

target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5)  # a prediction (logit)
pos_weight = torch.ones([64])  # all class weights equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)  # -log(sigmoid(1.5))
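Since every target is 1 and every logit is 1.5, the mean loss should equal the closed-form value -log(sigmoid(1.5)) ≈ 0.2014; a quick check:

```python
import math

import torch

target = torch.ones([10, 64], dtype=torch.float32)
output = torch.full([10, 64], 1.5)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.ones([64]))
loss = criterion(output, target)

# Closed form for all-ones targets with unit pos_weight.
expected = -math.log(1 / (1 + math.exp(-1.5)))  # ≈ 0.2014
```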

reference:

BCEWithLogitsLoss — PyTorch 1.12 documentation


Origin blog.csdn.net/qq_41021141/article/details/126001325