The SoftmaxCrossEntropyLoss Function in mxnet

While reading the source code of mxnet's loss functions, I found the implementation of SoftmaxCrossEntropyLoss interesting enough to write down.

SoftmaxCrossEntropyLoss

For the conceptual background, refer to the linked article.

p = softmax(pred)
L = -\sum_i \sum_j label_{ij} \log p_{ij}
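As a quick sanity check, the formula can be evaluated directly in NumPy (a standalone sketch, not mxnet code; the function names here are my own):

```python
import numpy as np

def softmax(pred):
    # shift by the row max for numerical stability
    e = np.exp(pred - pred.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def cross_entropy(pred, onehot):
    # L_i = -sum_j label_ij * log(p_ij)
    p = softmax(pred)
    return -(onehot * np.log(p)).sum(axis=-1)

# 2 samples, 7 classes, all-zero logits -> uniform p = 1/7 per class
pred = np.zeros((2, 7))
onehot = np.eye(7)[[0, 0]]  # both samples labeled class 0
print(cross_entropy(pred, onehot))  # both entries ≈ log(7) ≈ 1.9459
```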

An example call looks like this:

import mxnet as mx
import mxnet.gluon.loss as gloss
import numpy as np

a = np.zeros((2, 7))   # 2 samples, 7 classes
b = np.zeros((2, 1))   # both samples labeled class 0
pred = mx.nd.array(a)
label = mx.nd.array(b)
loss = gloss.SoftmaxCrossEntropyLoss()
print(loss(pred, label))

which outputs:

[1.9459101 1.9459101]
<NDArray 2 @cpu(0)>

At first glance this is puzzling: we clearly declared two all-zero ndarrays, so how does the computation produce 1.9459101?
Let's start with the value itself:

Running: print(math.log(1/7)) (after import math)

gives: -1.9459101490553135

So this odd value is simply -log(1/7). Now things get interesting: by the formula, the softmax cross-entropy loss multiplies the label by the log-probability, where p is computed by first applying softmax and then taking the log. Since the logits are all zero and there are 7 classes, softmax assigns probability 1/7 to every class. But where did the multiplication by the label go?
Recall that a multi-class label is really a one-hot array like [0,0,0,0,1,0,...,0]: exactly one entry is 1 and all others are 0. So the multiply-and-sum collapses into a single lookup; we only need the one term whose label entry is non-zero.
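This reduction from multiply-and-sum to a single lookup can be demonstrated in NumPy (an illustrative sketch, using the uniform 1/7 probabilities from the example above):

```python
import numpy as np

log_p = np.log(np.full((2, 7), 1 / 7))  # log-probabilities of a uniform softmax
label = np.array([0, 3])                # sparse class indices
onehot = np.eye(7)[label]               # dense one-hot version of the same labels

dense_loss = -(onehot * log_p).sum(axis=-1)   # multiply-and-sum form
sparse_loss = -log_p[np.arange(2), label]     # single-lookup form

print(dense_loss, sparse_loss)  # both are log(7) ≈ 1.9459101
```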

Now let's look at mxnet's actual implementation:

class SoftmaxCrossEntropyLoss(Loss):
    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
                 batch_axis=0, **kwargs):
        super(SoftmaxCrossEntropyLoss, self).__init__(
            weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        if not self._from_logits:
            pred = F.log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = _reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = _apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)
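The sparse-label path of hybrid_forward can be approximated in plain NumPy (a sketch of the logic only, with my own function names; mxnet's real kernels live in its backend):

```python
import numpy as np

def log_softmax(pred, axis=-1):
    # equivalent of F.log_softmax: log(softmax(x)) computed stably
    shifted = pred - pred.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def softmax_ce_loss(pred, label):
    # sparse-label branch: pick the log-probability of the true class
    log_p = log_softmax(pred)
    picked = np.take_along_axis(log_p, label.astype(int), axis=-1)
    # mean over the remaining non-batch axis, like F.mean(..., exclude=True)
    return -picked.mean(axis=-1)

pred = np.zeros((2, 7))   # the all-zero example from above
label = np.zeros((2, 1))  # class 0 for both samples
print(softmax_ce_loss(pred, label))  # both entries ≈ 1.9459101
```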

In most cases (sparse_label=True is the default), execution enters this branch of hybrid_forward:

loss = -F.pick(pred, label, axis=self._axis, keepdims=True)

so it is worth taking a closer look at the pick function:

    def pick(self, *args, **kwargs):
        """Convenience fluent method for :py:func:`pick`.

        The arguments are the same as for :py:func:`pick`, with
        this array as data.
        """
        return op.pick(self, *args, **kwargs)

Following the call chain:

def pick(data=None, index=None, axis=_Null, keepdims=_Null, mode=_Null, out=None, name=None, **kwargs):
    return (0,)

This body is only an auto-generated Python stub; the actual pick operator is implemented in mxnet's C++ backend and bound to this name when the package is imported.

Example usage of pick:

x = [[ 1.,  2.],
     [ 3.,  4.],
     [ 5.,  6.]]
x = mx.nd.array(x)
ret = mx.nd.pick(x, mx.nd.array([0,0,0]), axis=1)
print(ret)
>>> [1. 3. 5.]
    <NDArray 3 @cpu(0)>
    
ret = mx.nd.pick(x, mx.nd.array([0,1]), axis=0)
print(ret)
>>> [1. 4.]
    <NDArray 2 @cpu(0)>
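For reference, both pick calls above map onto NumPy's take_along_axis (a rough equivalent only; it keeps the indexed axis, so a squeeze is needed, and mxnet's operator additionally supports clip/wrap index modes):

```python
import numpy as np

x = np.array([[1., 2.],
              [3., 4.],
              [5., 6.]])

# like mx.nd.pick(x, [0, 0, 0], axis=1): one column index per row
ret1 = np.take_along_axis(x, np.array([[0], [0], [0]]), axis=1).squeeze(1)
print(ret1)  # [1. 3. 5.]

# like mx.nd.pick(x, [0, 1], axis=0): one row index per column
ret0 = np.take_along_axis(x, np.array([[0, 1]]), axis=0).squeeze(0)
print(ret0)  # [1. 4.]
```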

With that, the whole computation is clear: log_softmax turns the logits into log-probabilities, and pick selects, for each sample, the log-probability of its labeled class.

Reposted from blog.csdn.net/gaussrieman123/article/details/100142251