Understanding the meaning of nn.CrossEntropyLoss in PyTorch
According to the documentation of nn.CrossEntropyLoss, if the first argument (the output of the neural network) is a Tensor of shape (batch_size, num_class, h, w), where num_class is the number of classes, h is the image height, and w is the image width, then nn.CrossEntropyLoss expects the label to have shape (batch_size, h, w). For each item in the batch, the label stores the class index that each pixel belongs to: for binary classification the label values can only be 0 or 1; for three-class classification they can be 0, 1, or 2; and so on.
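To make the shape requirement concrete, here is a minimal sketch (the sizes and values are chosen purely for illustration) that shows the plain classification case next to the per-pixel, image-style case:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# Plain classification: logits of shape (batch_size, num_class),
# labels of shape (batch_size,) holding class indices.
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 1])
print(loss_fn(logits, labels))

# Image-style case: logits (batch_size, num_class, h, w),
# labels (batch_size, h, w) with one class index per pixel.
logits = torch.randn(4, 3, 8, 8)
labels = torch.randint(0, 3, (4, 8, 8))
print(loss_fn(logits, labels))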
Cross entropy measures the similarity between two distributions, so it can be used to estimate how close the actual output of a neural network is to the desired output. Suppose there are two distributions p(x) and q(x); the cross entropy of the two is $CE = -\sum_{x \in \chi} p(x)\log(q(x))$.
In a classification task, given a sample and its label, the sample can belong to only one class. Suppose the sample belongs to class k; then p(x=k)=1 and p(x≠k)=0, so the cross entropy between the network output and the label for this sample simplifies to $CE = -\log(q(x=k))$.
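A tiny numerical check (the distribution q and the class index k below are made up for illustration) confirms that with a one-hot p(x) the full sum collapses to this single term:

import numpy as np

q = np.array([0.2, 0.7, 0.1])     # predicted distribution q(x) over 3 classes
k = 1                             # the class the sample actually belongs to
p = np.zeros_like(q)              # one-hot p(x): p(x=k)=1, p(x!=k)=0
p[k] = 1.0

full_ce = -np.sum(p * np.log(q))  # -sum_x p(x) log(q(x))
simplified = -np.log(q[k])        # -log(q(x=k))
print(full_ce, simplified)        # both equal 0.3566...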
The output of a neural network is usually a vector whose length equals the number of classes. To turn this vector into a probability distribution, i.e. into the form of q(x=k), the softmax function must be applied to the network output. With the softmax folded in, the cross-entropy loss takes the following form, which is exactly the formula used by nn.CrossEntropyLoss: $\text{loss}(x, k) = -\log\left(\frac{\exp(x[k])}{\sum_j \exp(x[j])}\right)$
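Before the full image-sized example below, here is a one-sample sketch of that formula (the logits and the label are arbitrary); the value returned by nn.CrossEntropyLoss should match the manual -log(softmax) computation:

import math
import torch
import torch.nn as nn

x = torch.tensor([[1.0, 2.0, 0.5]])   # logits for one sample, 3 classes
k = torch.tensor([1])                 # its class label

ce = nn.CrossEntropyLoss()(x, k)
manual = -math.log(math.exp(x[0, 1]) / torch.exp(x[0]).sum())
print(ce.item(), manual)              # the two values agree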
import torch
import numpy as np
import torch.nn as nn
import math
a = torch.randn((4,3,8,8))
b = np.random.randint(0,3,(4,8,8))
b = torch.from_numpy(b)
loss_fn = nn.CrossEntropyLoss()
b = b.long()
loss = loss_fn(a, b)
loss
# tensor(1.3822)
m = nn.Softmax2d()
output = m(a)  # verify that Softmax2d applies softmax along the C dimension at every position
a01 = math.exp(a[0, 0, 0, 0])
a02 = math.exp(a[0, 1, 0, 0])
a03 = math.exp(a[0, 2, 0, 0])
aa = a01 + a02 + a03  # normalize over all 3 channels of the C dimension
print(a01 / aa)
print(a02 / aa)
print(output[0, 0, 0, 0])
print(output[0, 1, 0, 0])
loss = 0
for batch in range(4):
    for i in range(8):
        for j in range(8):
            if b[batch, i, j] == 1:
                loss = loss - math.log(output[batch, 1, i, j])
            if b[batch, i, j] == 0:
                loss = loss - math.log(output[batch, 0, i, j])
            if b[batch, i, j] == 2:
                loss = loss - math.log(output[batch, 2, i, j])
# Average the total loss over the number of samples; the number of samples is the
# number of pixels per image (8*8) times batch_size, i.e. 8*8*4.
print(loss / 64 / 4)
# 1.3822217100148755
As the results show, the manually computed loss equals the loss computed by loss_fn.
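The same average can also be obtained without the three nested loops. Below is a minimal vectorized sketch that reuses the output and b tensors defined above: torch.gather picks, for every pixel, the predicted probability of its ground-truth class, and the mean of -log over all 8*8*4 pixels reproduces the value reported by loss_fn.

log_probs = torch.log(output)                 # (4, 3, 8, 8)
picked = log_probs.gather(1, b.unsqueeze(1))  # (4, 1, 8, 8): log prob of the true class per pixel
print(-picked.mean())                         # ~1.3822, same as loss_fn(a, b)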