Pytorch 中的交叉熵CrossEntropyLoss

损失函数中出镜率还是特别高的,每次用它的时候需要把数据结构调整对了,交叉熵就是预测分类的所属概率,比如这个图片是属于哪一类,属于狗?猫?马?什么的,用网络训完之后会得到一个每个分类的预测概率。

看下面代码

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()

# input 
# tensor([[ 0.1743,  0.5794, -0.9481, -1.4114,  0.1286],
#         [ 0.4107,  1.2454,  1.0459, -0.1452,  0.8778],
#         [ 0.5576, -0.4521,  1.3982,  0.6421,  0.8372]],
# requires_grad=True)

# target 
# tensor([2, 0, 1])

可以看出来input输入的尺寸是torch.Size([3, 5])target 的尺寸是torch.Size([3]),如果按照图片分类的话可以这样理解,一次性输入三张图片,每张预测五个结果(允许里面有负数,也不一定非要在0~1之间,里面做了LogSoftmax操作),target则是准确分类的下标,因为概率只有一个是百分之百,其余的都是零,出现一个就好了。这样的话就可以做CrossEntropyLoss()的操作了

发布了163 篇原创文章 · 获赞 117 · 访问量 21万+

猜你喜欢

转载自blog.csdn.net/u010095372/article/details/102947041