[PyTorch] Implementing cross entropy yourself in PyTorch

Adapted from: https://blog.csdn.net/WYXHAHAHA123/article/details/88342571

Features

  • Supports ignore_index
  • Compares the results against PyTorch's built-in functions

code

import torch
import torch.nn.functional as F
'''
Three ways of computing the cross entropy loss
'''
input = torch.randn(10, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4, 2, 3, 4, 0, 0, 0, 1]).long()
output = F.nll_loss(F.log_softmax(input, dim=1), target)
print('loss1', output)
'''
F.nll_loss
negative log likelihood loss
'''
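'''
A minimal sanity check (not in the original post) of what F.nll_loss does with the
default 'mean' reduction: for each sample it takes the value at the target index of
the given log-probabilities, negates it, and averages; this should reproduce loss1.
'''
log_probs = F.log_softmax(input, dim=1)
nll_manual = -log_probs[torch.arange(target.shape[0]), target].mean()
print('nll_manual', nll_manual)  # expected to match loss1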
loss2 = F.cross_entropy(input, target, ignore_index=0)  # samples whose target is 0 are excluded from the average
print('loss2', loss2)
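'''
F.cross_entropy is log_softmax followed by nll_loss, so without ignore_index it
should match loss1 exactly; a quick check (not in the original post):
'''
loss2_plain = F.cross_entropy(input, target)
print('loss2_plain', loss2_plain)  # expected to match loss1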
'''
Steps for computing the cross entropy loss:
first apply F.softmax(input, dim=1) to the input (which at this point must be a 2-dimensional tensor of shape [num_samples, num_classes])
to obtain each sample's probability for every class, then take the logarithm: torch.log(F.softmax(input, dim=1)).
Finally, bring in the target and pick out, for each sample, the log of the probability assigned to its correct class; negating and averaging these values gives the cross entropy loss.
'''
probability = F.softmax(input, dim=1)  # shape [num_samples, num_classes]
log_P = torch.log(probability)
'''One-hot encode the target labels using the scatter_ method'''
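'''
A tiny illustration (made-up 3-sample / 4-class example) of what scatter_ does here:
writing 1 at each row's target column turns class indices into one-hot rows.
'''
demo = torch.zeros(3, 4).scatter_(1, torch.tensor([[2], [0], [3]]), 1)
print(demo)
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.],
#         [0., 0., 0., 1.]])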
one_hot = torch.zeros(probability.shape).scatter_(1, torch.unsqueeze(target, dim=1), 1)
loss3 = -one_hot * log_P     # keep only the (negated) log prob of each sample's true class
loss3 = loss3.sum(dim=1)     # one loss value per sample
mask = target.ne(0)          # mimic ignore_index=0: drop samples whose target is 0
loss3 = loss3.masked_select(mask)
loss3 = loss3.sum()
loss3 /= mask.sum()          # average over the non-ignored samples only
print('loss3',loss3)
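'''
The same result can be obtained without building a one-hot matrix (a sketch, not from
the original post): gather picks the log probability of each sample's true class
directly, and F.log_softmax(input, dim=1) would be a numerically more stable way to
get log_P than torch.log(F.softmax(...)).
'''
picked = log_P.gather(1, target.unsqueeze(1)).squeeze(1)  # log prob of the true class per sample
loss3_gather = -picked[mask].mean()                       # reuse the ignore_index-style mask from above
print('loss3_gather', loss3_gather)  # expected to match loss3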
'''loss4 adds a small constant (1e-3) to the softmax probabilities before taking the log; the original post calls this label smoothing, though strictly it only smooths the predicted probabilities (a standard label-smoothing sketch follows the output below)'''
loss4 = F.nll_loss(torch.log(F.softmax(input, dim=1) + 1e-3), target)
print('loss4', loss4)
'''
loss1 tensor(2.1232, grad_fn=<NllLossBackward>)
loss2 tensor(2.1232, grad_fn=<NllLossBackward>)
loss3 tensor(2.1232, grad_fn=<DivBackward0>)
loss4 tensor(2.1133, grad_fn=<NllLossBackward>)

Note: with ignore_index=0 in loss2 and the mask in loss3, those two should match each
other but will generally differ from loss1 (which averages over all samples); the
identical values above appear to come from a run without ignore_index.
'''
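'''
For reference, standard label smoothing (a sketch, not part of the original post, with
an assumed smoothing factor eps=0.1) mixes the one-hot target with a uniform
distribution over the C classes: the soft target is (1 - eps) * one_hot + eps / C, and
the loss is the cross entropy against that soft target.
'''
eps = 0.1                        # assumed smoothing factor
num_classes = input.shape[1]
log_probs_ls = F.log_softmax(input, dim=1)
soft_target = torch.zeros_like(log_probs_ls).scatter_(1, target.unsqueeze(1), 1)
soft_target = soft_target * (1 - eps) + eps / num_classes
loss_ls = -(soft_target * log_probs_ls).sum(dim=1).mean()
print('loss_ls', loss_ls)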


Reposted from blog.csdn.net/u011622208/article/details/105746191