PyTorch 中的交叉熵函数 CrossEntropyLoss 的计算过程

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/diyoosjtu/article/details/89141554

CrossEntropyLoss() 函数联合调用了 nn.LogSoftmax() 和 nn.NLLLoss()。

假设网络得到的输出为 h h ,它的维度大小为 B × C B\times C ,其中 B B 是 batch_size, C C 是分类的总数目。与之对应的训练数据的标签 y y 维度是 1 × B 1\times B y y 中元素的取值范围是 [ 0 , C 1 ] [0, C-1] ,即
0 y [ j ] C 1 j = 0 , 1 ,   , B 1 0\le y[j]\le C-1 \qquad j = 0, 1, \cdots, B-1

我们将CrossEntropyLoss() 函数的计算过程拆解为如下两个步骤:

  1. 对输出 h h ,执行LogSoftmax(dim=1),得到 s s ,维度仍然是 B × C B\times C
  2. s s 执行 log ( ) -\log() 操作,得到负对数概率 p p ,维度仍然是 B × C B\times C

则交叉熵的计算公式为:
(1) L = 1 B i = 0 B { log ( p [ i , y [ i ] ] ) } L = \frac{1}{B}\sum_{i=0}^B\left\{-\log(p[i,y[i]])\right\} \tag{1}

式(1)其实是从式(2)化简得来的:
(2) L = 1 B i = 0 B { j = 0 C 1 y [ i , j ] log ( p [ i , j ] ) } L = \frac{1}{B}\sum_{i=0}^B\left\{-\sum_{j=0}^{C-1}y[i, j]\log(p[i,j])\right\} \tag{2}

举例说明:

对于 C = 10 C=10 y = [ 7 , 7 , 2 , 4 ] y=[7, 7, 2, 4] 的情况,可知 B = 4 B=4 ,首先需要把 y y 扩展为 B × C B\times C 的矩阵:
y = [ 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 ] y = \begin{bmatrix} 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0\\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0\\ 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0\\ 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 \end{bmatrix}
其中为1的元素位置,就是最终概率 p p 中需要取值的位置。

网络得到的输出
h = [ 0.1070 0.0083 0.0789 0.0341 0.0686 0.0088 0.0540 0.1017 0.0267 0.0925 0.0977 0.0053 0.0613 0.0576 0.0690 0.0104 0.0558 0.1133 0.0502 0.0775 0.1049 0.0091 0.0663 0.0611 0.0709 0.0168 0.0602 0.1072 0.0477 0.0878 0.1164 0.0018 0.0746 0.0531 0.0670 0.0142 0.0700 0.1005 0.0491 0.0939 ] h = \begin{bmatrix} -0.1070 & 0.0083 & -0.0789 & 0.0341 & 0.0686 & -0.0088 & 0.0540 & -0.1017 & 0.0267 & 0.0925\\ -0.0977 & -0.0053 & -0.0613 & 0.0576 & 0.0690 & -0.0104 & 0.0558 & -0.1133 & 0.0502 & 0.0775\\ -0.1049 & -0.0091 & -0.0663 & 0.0611 & 0.0709 & -0.0168 & 0.0602 & -0.1072 & 0.0477 & 0.0878\\ -0.1164 & -0.0018 & -0.0746 & 0.0531 & 0.0670 & -0.0142 & 0.0700 & -0.1005 & 0.0491 & 0.0939 \end{bmatrix}

s = [ 0.0898 0.1007 0.0923 0.1034 0.1070 0.0990 0.1054 0.0902 0.1026 0.1096 0.0903 0.0990 0.0936 0.1055 0.1067 0.0985 0.1053 0.0889 0.1047 0.1076 0.0896 0.0986 0.0931 0.1058 0.1068 0.0979 0.1057 0.0894 0.1044 0.1087 0.0886 0.0993 0.0923 0.1049 0.1064 0.0981 0.1067 0.0900 0.1045 0.1093 ] s = \begin{bmatrix} 0.0898 & 0.1007 & 0.0923 & 0.1034 & 0.1070 & 0.0990 & 0.1054 & 0.0902 & 0.1026 & 0.1096\\ 0.0903 & 0.0990 & 0.0936 & 0.1055 & 0.1067 & 0.0985 & 0.1053 & 0.0889 & 0.1047 & 0.1076\\ 0.0896 & 0.0986 & 0.0931 & 0.1058 & 0.1068 & 0.0979 & 0.1057 & 0.0894 & 0.1044 & 0.1087\\ 0.0886 & 0.0993 & 0.0923 & 0.1049 & 0.1064 & 0.0981 & 0.1067 & 0.0900 & 0.1045 & 0.1093 \end{bmatrix}

p = [ 2.4107 2.2954 2.3826 2.2696 2.2351 2.3125 2.2497 2.4054 2.2770 2.2112 2.4048 2.3124 2.3684 2.2495 2.2381 2.3175 2.2513 2.4204 2.2569 2.2296 2.4123 2.3165 2.3737 2.2463 2.2365 2.3242 2.2472 2.4146 2.2597 2.2196 2.4242 2.3096 2.3824 2.2547 2.2408 2.3220 2.2378 2.4083 2.2587 2.2139 ] p = \begin{bmatrix} 2.4107 & 2.2954 & 2.3826 & 2.2696 & 2.2351 & 2.3125 & 2.2497 & 2.4054 & 2.2770 & 2.2112\\ 2.4048 & 2.3124 & 2.3684 & 2.2495 & 2.2381 & 2.3175 & 2.2513 & 2.4204 & 2.2569 & 2.2296\\ 2.4123 & 2.3165 & 2.3737 & 2.2463 & 2.2365 & 2.3242 & 2.2472 & 2.4146 & 2.2597 & 2.2196\\ 2.4242 & 2.3096 & 2.3824 & 2.2547 & 2.2408 & 2.3220 & 2.2378 & 2.4083 & 2.2587 & 2.2139 \end{bmatrix}

因此,最终的交叉熵
L = 2.4054 + 2.4204 + 2.3737 + 2.2408 4 = 2.36 L = \frac{2.4054 + 2.4204 + 2.3737 + 2.2408 }{4} = 2.36

猜你喜欢

转载自blog.csdn.net/diyoosjtu/article/details/89141554