PyTorch Code Pitfalls

1. The weight parameter of CrossEntropyLoss

torch.nn.CrossEntropyLoss takes a parameter called weight, described in the official docs as "a manual rescaling weight given to each class", i.e., a per-class scaling factor. However, I found that when the reduction parameter is 'mean' (the default), the result is not what you might expect.

import torch
import torch.nn as nn

pred = torch.tensor([[1, 5], [2, 2]]).float()
label = torch.tensor([0, 1]).long()

# Per-sample losses, no reduction
c = nn.CrossEntropyLoss(reduction='none')
print(c(pred, label))

# Default reduction='mean'
c = nn.CrossEntropyLoss()
print(c(pred, label))

# Equal class weights with the default 'mean' reduction
c = nn.CrossEntropyLoss(weight=torch.tensor([2.0, 2.0]))
print(c(pred, label))

# Equal class weights, per-sample losses
c = nn.CrossEntropyLoss(weight=torch.tensor([2.0, 2.0]), reduction='none')
print(c(pred, label))

# Taking the mean of the weighted per-sample losses manually
c = nn.CrossEntropyLoss(weight=torch.tensor([2.0, 2.0]), reduction='none')
print(c(pred, label).mean())

The output is:

tensor([4.0181, 0.6931])
tensor(2.3556)
tensor(2.3556)
tensor([8.0363, 1.3863])
tensor(4.7113)

As the output shows, when weight is combined with reduction='mean', the weight is not ignored; rather, the loss is not simply each per-sample loss scaled by its class weight and then averaged over the batch. With reduction='mean', PyTorch divides the weighted sum of the per-sample losses by the sum of the applied weights, sum_n w[y_n], rather than by the batch size N. That is why the weighted mean above (2.3556) equals the unweighted one: the factor of 2 appears in both the numerator and the denominator and cancels. If you want a plain scaling (weighted losses averaged over N, giving 4.7113 here), set reduction='none' and call .mean() yourself.
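A minimal sketch to verify this normalization (using unequal class weights, which I chose so the two means actually differ; variable names are my own):

import torch
import torch.nn as nn

pred = torch.tensor([[1., 5.], [2., 2.]])
label = torch.tensor([0, 1])
w = torch.tensor([2.0, 3.0])  # unequal weights so the two means differ

# Per-sample weighted losses: w[label[n]] * l_n
losses = nn.CrossEntropyLoss(weight=w, reduction='none')(pred, label)

# reduction='mean' divides by the sum of the applied weights, not by N
builtin = nn.CrossEntropyLoss(weight=w)(pred, label)
manual = losses.sum() / w[label].sum()
print(builtin, manual)  # identical: tensor(2.0231) tensor(2.0231)

# A plain average over the batch must be taken manually
print(losses.mean())    # losses.sum() / 2 = tensor(5.0579)

This normalization is deliberate: dividing by the sum of the applied weights keeps the loss magnitude stable regardless of which classes happen to appear in a batch, but it does mean that uniformly scaling all class weights has no effect under reduction='mean'.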

Reposted from blog.csdn.net/wqwqqwqw1231/article/details/106195522