关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

目录

1. 交叉熵损失 CrossEntropyLoss

2. ignore_index 参数

3. weight 参数

4. 例子


1. 交叉熵损失 CrossEntropyLoss

CrossEntropyLoss 交叉熵损失可函数以用于分类或者分割任务中,这里主要介绍分割任务

建立如下的数据,pred是预测样本,label是真实标签

分割中,使用交叉熵损失的话,需要保证label的维度比pred维度少1,也就是没有channel维度。并且,label的类型是int

正常计算损失结果为:

手动计算一下,pred的softmax为

所以,loss = -(ln0.69+ln0.3543+ln0.5987)/3 = -(ln0.1464) / 3 = 0.6406 

后面的是计算产生的误差,这里用数学方法简化计算了

one-hot 编码,只计算label的 ln 预测值

2. ignore_index 参数

在分割任务中,经常有像素点是认为不感兴趣的,所以这里ignore_index可以将那些不感兴趣的像素点排除

import torch
import torch.nn as nn
import torch.nn.functional as F


pred = torch.Tensor([[0.9, 0.1],[0.8, 0.2],[0.7, 0.3]])     # 预测值 size = 3*2, dtype = torch.float32
label = torch.LongTensor([0, 1, 0])                         # 真实值 size = 3 , dtype = torch.int64
loss = nn.CrossEntropyLoss(ignore_index=1)
out = loss(pred,label)
print(out)      # tensor(0.4421)

这里将label = 1的像素点排除,手动计算一下

loss = (-ln0.69-ln0.5987) / 2 = 0.4421 

这里将label = 1的忽略了,下面是pred的softmax值

3. weight 参数

当涉及到样本的个数不平衡的时候,可以将样本少的label,w加大点

import torch
import torch.nn as nn
import torch.nn.functional as F


pred = torch.Tensor([[0.9, 0.1],[0.8, 0.2],[0.7, 0.3]])     # 预测值 size = 3*2, dtype = torch.float32
label = torch.LongTensor([0, 1, 0])                         # 真实值 size = 3 , dtype = torch.int64
w = torch.FloatTensor([1,2])
loss = nn.CrossEntropyLoss(weight=w)
out = loss(pred,label)
print(out)      # tensor(0.7398)

计算方法是:

loss =- ( 1*ln0.69 + 2*ln0.3543+1*ln0.5987) / 4 = (0.3711 + 2.0741+ 0.5130) / 4= 0.7396

可以发现答案是类似的,这里保留了四位小数进行计算,所以有误差

因为,label = 1有一个,label = 0 有两个,所以1的样本较少,这里就对label = 1设置权重大点。可以发现,计算出来的loss确实比不加loss的大,下图为不加w的

如果将w改成[2,1]的话,loss会更低,不利于loss的下降

所以,在样本不均衡的情况下,加label少的样本,w加大,可以将loss变大,从而梯度下降的时候可以更好的弥补样本不平衡的问题

注意:w的类型是float

4. 例子

测试代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


pred = torch.Tensor([[0.9, 0.1,0.2],[0.8, 0.2,0.1],[0.7, 0.3,0.5],[0.1,0.5,0.6]])
label = torch.LongTensor([2, 1, 0,1])

s = F.softmax(pred,dim=1)
print(s)

w = torch.FloatTensor([2,1,2])
loss = nn.CrossEntropyLoss(weight=w,ignore_index=2)
out = loss(pred,label)
print(out)      # tensor(1.0401)

其中,pred的softmax如下:

label 为:2 1 0 1

可以发现,label 是 0 1 2 三类,这里将label = 2的忽略,并且对0 1 2施加的权重为 2 1 2

所以手动计算的公式为,这里精确到六位小数

label = 0 的损失 = - ln0.4018 = 0.911801

label = 1 的损失 = (- ln0.2683 - ln0.3603 ) / 2 = (1.315650 + 1.020818)/2 = 1.168234

label = 2 的损失 = - ln0.2552 = 1.365708

这里忽略了label = 2,所以还剩:

label = 0 的损失 = - ln0.4018 = 0.911801

label = 1 的损失 = (- ln0.2683 - ln0.3603 ) / 2 = (1.315650 + 1.020818)/2 = 1.168234

并且对0 1 进行加权2 1

所以总的loss = (0.911801 *2 + 1.315650*1+1.020818*1) /(2+1+1) = 4.16007/4=1.0400175

可以发现结果是一样的,这里最后是精度问题

猜你喜欢

转载自blog.csdn.net/qq_44886601/article/details/130124828
今日推荐