Three ways to compute the cross-entropy loss in PyTorch

import torch
import torch.nn.functional as F
'''
Three ways to compute the cross-entropy loss
'''
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4]).long()

output = F.nll_loss(F.log_softmax(input,dim=1), target)
print('loss1',output)
'''
F.nll_loss
negative log likelihood loss: for each sample it takes the log-probability at the target class,
negates it, and (with the default reduction) averages over the batch
'''
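As a quick sanity check (not part of the original recipe), F.nll_loss with the default reduction can be reproduced by hand: gather each sample's log-probability at its target index, negate, and average.

log_p=F.log_softmax(input,dim=1)
picked=log_p.gather(1,target.unsqueeze(1)).squeeze(1)  # log-prob of the true class per sample
print('manual nll',-picked.mean())  # matches loss1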

loss2=F.cross_entropy(input,target)
print('loss2',loss2)

'''
Steps for computing the cross-entropy loss:
First apply F.softmax(input,dim=1) to the input (the input tensor must be 2-dimensional here: shape [num_samples,num_classes])
to obtain each sample's probability for every class, then take the log of the probabilities: torch.log(F.softmax(input,dim=1)).
Finally, bring in the target: pick out, for each sample, the log-probability of its correct class;
the negative mean of these values is the cross-entropy loss
'''
probability=F.softmax(input,dim=1)#shape [num_samples,num_classes]
log_P=torch.log(probability)
'''one-hot encode the target labels using the scatter_ method'''
one_hot=torch.zeros(probability.shape).scatter_(1,torch.unsqueeze(target,dim=1),1)
loss3=-one_hot*log_P
loss3=loss3.sum()
loss3/=probability.shape[0]
print('loss3',loss3)
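As an aside, newer PyTorch versions (1.1 and later) provide F.one_hot, which builds the same mask as the scatter_ call above; a minimal equivalent sketch:

one_hot_v2=F.one_hot(target,num_classes=probability.shape[1]).float()
loss3_v2=(-one_hot_v2*log_P).sum()/probability.shape[0]
print('loss3_v2',loss3_v2)  # same value as loss3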

'''loss4 adds a small constant to the softmax probabilities before taking the log. Strictly speaking this is
numerical smoothing of the predictions rather than label smoothing, which smooths the target distribution
instead; see the sketch after the printed results below'''
loss4=F.nll_loss(torch.log(F.softmax(input,dim=1)+1e-3),target)
print('loss4',loss4)
'''
loss1 tensor(2.1232, grad_fn=<NllLossBackward>)
loss2 tensor(2.1232, grad_fn=<NllLossBackward>)
loss3 tensor(2.1232, grad_fn=<DivBackward0>)
loss4 tensor(2.1133, grad_fn=<NllLossBackward>)
'''
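For reference, true label smoothing (Szegedy et al., 2016) mixes the one-hot target with a uniform distribution and keeps the predicted probabilities untouched. A minimal sketch, assuming a smoothing factor eps=0.1 (not part of the original post):

eps=0.1
num_classes=input.shape[1]
smooth_target=one_hot*(1-eps)+eps/num_classes  # soften the one-hot target
loss_ls=(-smooth_target*F.log_softmax(input,dim=1)).sum(dim=1).mean()
print('label-smoothed CE',loss_ls)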


When computing the cross-entropy loss with F.cross_entropy, it is easy to overlook what the weight argument actually does. First consider weight=None: with reduction='none' (reduce=False in the old, deprecated API) the function returns a tensor with one loss value per sample; with reduction='mean' (the default, reduce=True in the old API) it sums the per-sample losses and divides by the number of samples. Note that the denominator equals the number of samples only when weight=None, i.e. every sample in the batch carries the same weight of 1. Once weight is given, the denominator is no longer simply the number of samples: each sample's loss is multiplied by the weight of its target class, and the sum is divided by the sum of those per-sample weights.

import torch
import numpy as np
import torch.nn.functional as F
pred=torch.rand((2,6,5,5))  # e.g. segmentation logits: [batch, classes, H, W]
y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))  # integer labels in [0, 6)
c=pred.shape[1]  # number of classes
pred=pred.permute(0,2,3,1).contiguous().view(-1,c)  # flatten to [N, C]
y=y.view(-1).long()  # flatten targets to [N]

weights=torch.tensor((0.2,0.1,0.4,0.15,0.05,0.1))

loss1=F.cross_entropy(pred,y,weights,reduction='mean')  # reduce=True in the deprecated API

loss2=F.cross_entropy(pred,y,weights,reduction='none').sum()  # per-sample weighted losses, summed
samples_weights=weights[y]  # weight of each sample's target class
loss2/=torch.sum(samples_weights)  # denominator: sum of weights, not the sample count

print('loss1',loss1)
print('loss2',loss2)

'''
loss1 tensor(1.8918)
loss2 tensor(1.8918)
'''
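The same weighted mean is also available through the module interface; a minimal sketch, assuming a PyTorch version with the reduction argument:

criterion=torch.nn.CrossEntropyLoss(weight=weights,reduction='mean')
print('loss via module',criterion(pred,y))  # matches loss1 and loss2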

Reprinted from blog.csdn.net/WYXHAHAHA123/article/details/88342571