torch cross_entropy ignore_index

  • ignore_index:指定标签为什么的时候不参与损失的计算
ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. When :attr:`size_average` is
            ``True``, the loss is averaged over non-ignored targets.

在这里插入图片描述

示例如下,标签为0不参与损失的计算:

x = torch.Tensor([[0.9, 0.1], [0.8, 0.2], [0.7, 0.3]])
label = torch.LongTensor([0, 0, 0])
loss = F.cross_entropy(x, label, ignore_index=0)
loss2 = F.cross_entropy(x, label)
print(loss, loss2)
tensor(0.) tensor(0.4405)

猜你喜欢

转载自blog.csdn.net/qq_42363032/article/details/127316136
今日推荐