While reading the source of mxnet's loss functions, I found the implementation of SoftmaxCrossEntropyLoss interesting, so I'm writing it down here.
SoftmaxCrossEntropyLoss
For the conceptual background, see this article.
p = \mathrm{softmax}({pred})
L = -\sum_i \sum_j {label}_{ij} \log p_{ij}
Example usage:
import mxnet.gluon.loss as gloss
import mxnet as mx
import numpy as np

a = np.zeros((2, 7))
b = np.zeros((2, 1))
pred = mx.nd.array(a)
label = mx.nd.array(b)
loss = gloss.SoftmaxCrossEntropyLoss()
print(loss(pred, label))
The output:
[1.9459101 1.9459101]
<NDArray 2 @cpu(0)>
At first glance this is baffling! We clearly declared two all-zero ndarrays, so where does 1.9459101 come from?
Let's look at this value first.
Running (after `import math`):
print(math.log(1/7))
gives:
-1.9459101490553135
So this odd value is the negative of log(1/7). That starts to make sense: pred is all zeros, and the softmax of a length-7 vector of zeros is uniform, so every probability p is 1/7 and every entry of log_softmax(pred) is log(1/7). By the formula, the loss is the label multiplied by these log-probabilities (softmax first, then log), but where did the multiplication by the label go?
We know that a multi-class (one-hot) label is really an array like [0,0,0,0,1,0,...,0]: exactly one entry is 1 and the rest are 0. So the multiply-then-sum over classes collapses to a single term: we only need the log-probability at the one nonzero position.
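The collapse from a per-class sum to a single lookup is easy to check with plain NumPy (a small sketch, not mxnet code; the class index 4 is an arbitrary choice):

```python
import numpy as np

# log-probabilities of a uniform softmax over 7 classes
logp = np.log(np.full(7, 1 / 7))

# one-hot label with the "true" class at index 4 (arbitrary choice)
onehot = np.zeros(7)
onehot[4] = 1.0

# the full multiply-and-sum over classes...
dense = -(onehot * logp).sum()
# ...equals a single lookup at the nonzero position
picked = -logp[4]

print(dense, picked)  # both are log(7) ≈ 1.9459101
```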
Let's look at mxnet's implementation:
class SoftmaxCrossEntropyLoss(Loss):
    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
                 batch_axis=0, **kwargs):
        super(SoftmaxCrossEntropyLoss, self).__init__(
            weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        if not self._from_logits:
            pred = F.log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = _reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = _apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)
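For reference, the sparse-label path above can be sketched in plain NumPy (weighting omitted; this is an illustration under those assumptions, not mxnet's actual kernel):

```python
import numpy as np

def softmax_ce_sparse(pred, label):
    """NumPy sketch of the sparse_label=True path:
    log_softmax along the class axis, then pick the true-class entry."""
    shifted = pred - pred.max(axis=-1, keepdims=True)  # stable log_softmax
    logp = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # analogue of F.pick: one log-probability per sample
    picked = logp[np.arange(pred.shape[0]), label.astype(int)]
    return -picked

pred = np.zeros((2, 7))   # two samples, seven classes
label = np.zeros(2)       # both labeled class 0
print(softmax_ce_sparse(pred, label))  # both entries ≈ 1.9459101
```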
In most cases (sparse_label=True is the default), execution takes this branch of hybrid_forward:
loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
so it is worth taking a closer look at the pick function:
def pick(self, *args, **kwargs):
    """Convenience fluent method for :py:func:`pick`.

    The arguments are the same as for :py:func:`pick`, with
    this array as data.
    """
    return op.pick(self, *args, **kwargs)
Following it one level deeper, we land in the operator stub:
def pick(data=None, index=None, axis=_Null, keepdims=_Null, mode=_Null, out=None, name=None, **kwargs):
    return (0,)
The `return (0,)` body is just a placeholder: this Python file is auto-generated from mxnet's operator registry, and the actual pick kernel is implemented in the C++ backend. What matters is the semantics: pick selects, for each position in `index`, one element from `data` along `axis`.
Example use of pick:
x = [[ 1., 2.],
[ 3., 4.],
[ 5., 6.]]
x = mx.nd.array(x)
ret = mx.nd.pick(x, mx.nd.array([0,0,0]), axis=1)
print(ret)
>>> [1. 3. 5.]
<NDArray 3 @cpu(0)>
ret = mx.nd.pick(x, mx.nd.array([0,1]), axis=0)
print(ret)
>>> [1. 4.]
<NDArray 2 @cpu(0)>
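pick has a direct NumPy analogue via np.take_along_axis; this hypothetical helper (my own naming, not mxnet API) mirrors the two calls above:

```python
import numpy as np

x = np.array([[1., 2.],
              [3., 4.],
              [5., 6.]])

def pick(data, index, axis):
    # hypothetical helper mirroring mx.nd.pick with NumPy:
    # select one element from `data` along `axis` per `index` entry
    idx = np.expand_dims(index.astype(int), axis)
    return np.take_along_axis(data, idx, axis=axis).squeeze(axis)

print(pick(x, np.array([0, 0, 0]), axis=1))  # [1. 3. 5.]
print(pick(x, np.array([0, 1]), axis=0))     # [1. 4.]
```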
At this point the whole process is clear: log_softmax turns the predictions into log-probabilities, and pick pulls out, for each sample, the log-probability of its labeled class.