损失函数中出镜率还是特别高的,每次用它的时候需要把数据结构调整对了,交叉熵就是预测分类的所属概率,比如这个图片是属于哪一类,属于狗?猫?马?什么的,用网络训完之后会得到一个每个分类的预测概率。
看下面代码
import torch
import torch.nn as nn
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
# input
# tensor([[ 0.1743, 0.5794, -0.9481, -1.4114, 0.1286],
# [ 0.4107, 1.2454, 1.0459, -0.1452, 0.8778],
# [ 0.5576, -0.4521, 1.3982, 0.6421, 0.8372]],
# requires_grad=True)
# target
# tensor([2, 0, 1])
可以看出来input
输入的尺寸是torch.Size([3, 5])
而target
的尺寸是torch.Size([3])
,如果按照图片分类的话可以这样理解,一次性输入三张图片,每张预测五个结果(允许里面有负数,也不一定非要在0~1之间,里面做了LogSoftmax
操作),target
则是准确分类的下标,因为概率只有一个是百分之百,其余的都是零,出现一个就好了。这样的话就可以做CrossEntropyLoss()
的操作了