torch.nn.BCELoss

torch.nn.BCELoss se utiliza para calcular la pérdida de entropía cruzada para problemas de clasificación binaria o problemas de clasificación de etiquetas múltiples.

torch.nn.BCELoss debe usarse con la función Sigmoid.


Problema de clasificación binaria

Para el problema de clasificación binaria, si se usa la función Softmax, el número de neuronas en la última capa completamente conectada es 2; si se usa la función Sigmoid, el número de neuronas en la última capa completamente conectada es 1. Suponiendo que existe un problema de clasificación binaria de gatos y perros, el valor de salida de la función sigmoidea se expresa como la probabilidad de un gato.

import torch
import torch.nn as nn

'''
对于一个2分类问题,通常使用0和1作为标签。这里假设猫的标签为0,狗的标签为1
当batch_size = 3时,这一个batch的标签是一个形状为[batch_size]的tensor,即shape为[5]
'''

# 一个batch(batch_size=3)的标签:[狗, 猫, 猫]
label = torch.tensor([1, 0, 0], dtype=torch.float32)

'''
对于一个2分类问题,当训练时batch_size为3,
则深度网络对每一个batch的预测值是一个形状为[batch_size]的tensor,即shape为[3]
以深度网络对第一个样本(狗)的预测值 -0.8 为例,经过Sigmoid层后,得到0.3100, 表示深度网络认为第一个样本属于猫的概率分别为0.3100。
'''

predict = torch.tensor([-0.8, 0.7, 0.5])

# 转为概率值
sigmoid = nn.Sigmoid()
print(sigmoid(predict)) # tensor([0.3100, 0.6682, 0.6225])

# 当reduction='none'时,输出是对每一个样本预测的损失
loss_func = torch.nn.BCELoss(reduction='none')
loss = loss_func(Sigmoid(predict), label)
print(loss) # tensor([1.1711, 1.1032, 0.9741])
 
# 当reduction='sum'时,输出是对这一个batch预测的损失之和
loss_func = torch.nn.BCELoss(reduction='sum')
loss = loss_func(Sigmoid(predict), label)
print(loss) # tensor(3.2484)

# 当reduction='mean'时,输出是对这一个batch预测的平均损失
loss_func = torch.nn.BCELoss(reduction='mean')
loss = loss_func(Sigmoid(predict), label)
print(loss) # tensor(1.0828)

torch.nn.BCEWithLogitsLoss

También puede usar torch.nn.BCEWithLogitsLoss directamente, que tiene una operación Sigmoid integrada.

import torch
import torch.nn as nn

# 一个batch(batch_size=3)的标签:[狗, 猫, 猫]
label = torch.tensor([1, 0, 0], dtype=torch.float32)

# 深度网络对这一个batch的预测值
predict = torch.tensor([-0.8, 0.7, 0.5])

# 当reduction='none'时,输出是对每一个样本预测的损失
loss_func = torch.nn.BCEWithLogitsLoss(reduction='none')
print(loss_func(predict, label)) # tensor([1.1711, 1.1032, 0.9741])

# 当reduction='sum'时,输出是对这一个batch预测的损失之和
loss_func = torch.nn.BCEWithLogitsLoss(reduction='sum')
print(loss_func(predict, label)) # tensor(3.2484)
 
# 当reduction='mean'时,输出是对这一个batch预测的平均损失
loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean')
print(loss_func(predict, label)) # tensor(1.0828)

Problema de clasificación de etiquetas múltiples

Suponiendo que en una imagen puede haber gatos, perros o personas, si no hay registro será 0, y si lo hay será 1.

Por ejemplo, si la etiqueta de una imagen es [1, 0, 0], significa que hay gatos en la imagen, pero no perros ni personas.

import torch
import torch.nn as nn

'''
对于一个3标签分类问题,
当batch_size = 2时,这一个batch的label是一个形状为[batch_size, label_classes]的tensor,即shape为[2, 3]
'''
# 一个batch(batch_size=2)的label
label = torch.tensor([[1, 0, 0],
                      [0, 1, 0]], dtype=torch.float32)

'''
对于一个3标签分类问题,当训练时batch_size为2,
则深度网络对每一个batch的预测值是一个形状为[batch_size, label_classes]的tensor,即shape为[2, 3]
以深度网络对第一个样本的预测值[-0.8, 0.7, 0.5]为例,经过Sigmoid层后,得到[0.7311, 0.5000, 0.5000], 
表示深度网络认为第一个样本有猫、狗、人的概率分别为0.7311、0.5000、0.5000
'''
predict = torch.tensor([[-0.8, 0.7, 0.5],
                        [0.2, -0.4, 0.6]])

# 当reduction='mean'时,输出是对这一个batch预测的平均损失
loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean')
loss = loss_func(predict, label)
print(loss) # tensor(0.9995)

torch.nn.MultiLabelSoftMarginLoss

También puede usar torch.nn.MultiLabelSoftMarginLoss directamente

import torch
import torch.nn as nn

# 一个batch(batch_size=2)的label
label = torch.tensor([[1, 0, 0],
                      [0, 1, 0]], dtype=torch.float32)

predict = torch.tensor([[-0.8, 0.7, 0.5],
                        [0.2, -0.4, 0.6]])

# 当reduction='mean'时,输出是对这一个batch预测的平均损失
loss_func = torch.nn.MultiLabelSoftMarginLoss​​​​​​​(reduction='mean')
loss = loss_func(predict, label)
print(loss) # tensor(0.9995)

Supongo que te gusta

Origin blog.csdn.net/weixin_46566663/article/details/127911813
Recomendado
Clasificación