torch.nn.CrossEntropyLoss使用的一些注意点

火大火大,耗费自己好多时间,就因为torch.nn.CrossEntropyLoss没怎么仔细看doc,记录一下加深记忆。

CrossEntropyLoss算的是prediction和target之间的交叉熵损失,一般都使用在classification task。一般而言,在计算交叉熵之前需要把model的prediction先用softmax进行normalize随后在用极大似然估计来计算loss,具体公式可以自己百度。
但是torch的CrossEntropyLoss有点不一样,它已经集成了log_softmax:
在这里插入图片描述
所以doc里面明确说了,喂给CrossEntropyLoss的input要是原始的:
在这里插入图片描述
燃鹅我的代码应该改成:
在这里插入图片描述
为给CrossEntropyLoss的input就直接是model的out,而不是经过了softmax的result.(由于我需要获得model预测的label,所以我还需要额外用softmax获得label的概率分布result)

火大,所以看doc,tf和torch还是不一样,会confusing

总结一下使用torch.nn.CrossEntropyLoss的注意点:

  1. input必须是raw的,因为CrossEntropyLoss = log_softmax + NllLoss,无需手动再进行softmax
  2. label必须是long类型,而且value在[0,class-1]之间
  3. 注意!!torch.nn.CrossEntropyLoss其实也是一个module,如果你在GPU上训练,那么它也要to("cuda")

参考资料:
torch1.3.0 doc

猜你喜欢

转载自blog.csdn.net/weixin_43301333/article/details/113819686
今日推荐