斯坦福CS231n assignment1:softmax损失函数求导

版权声明:原创文章,保留所有版权 https://blog.csdn.net/rootmego/article/details/84327899

斯坦福CS231n assignment1:softmax损失函数求导

分类

在前文斯坦福CS231n assignment1:SVM图像分类原理及实现中我们讲解了利用SVM模型进行图像分类的方法,本文我们讲解图像分类的另一种实现,利用softmax进行图像分类。

softmax和svm模型网络结构很相似,区别在于softmax会对svm的输出分量进行归一化处理,使得每一个输出分量变成一个概率值,所有输出分量的概率之和为1。
归一化概率

同时损失函数也发生了变化,svm的损失函数折叶损失(hinge loss)是针对样本的标记类别之外的其他类别进行损失计算的,也就是说标记类别不计入损失,其他类别计算损失并累加作为某个样本的损失。而softmax的损失函数交叉熵损失(cross-entropy loss)只跟某个样本的标记类别相关,根据该标记类别的概率计算损失值,而不考虑标记类别之外的其他类别。
svm和softmax损失函数的计算比较

svm得出的每个输出节点的得分,比如[98, 33, 15]是无标定的,也就是只是一个相对的大小,难以进行直观的解释。而softmax可以解释为实例被划分为某个类别的可能性,或者概率。

下面是softmax的损失函数:

softmax损失函数

也可以等价成:
softmax损失函数等价写法

加入正则化损失项后,批处理过程中N个样本的平均损失变成:
加入正则项的损失函数

这里我们使用L2正则化损失:
L2正则化损失项

在此基础上我们来推导损失函数L对权重Wij的偏导数,推导过程如下:
softmax交叉熵梯度计算

在这个推导过程中需要注意的是,直接跟标记类对应的输出节点相连的权重和不跟标记类节点相连的权重的偏导数格式是不一样的,对应于推导过程中的if/else判别。

对应的代码如下:

def softmax_loss_naive(W, X, y, reg):
  """
  :param X: 200 X 3073
  :param Y: 200
  :param W: 3073 X 10
  :return: reg: 正则化损失系数(无法通过拍脑袋设定,需要多试几个值,然后找个最优的)
  """
  dW = np.zeros(W.shape) # initialize the gradient as zero

  # compute the loss and the gradient
  num_classes = W.shape[1]
  num_train = X.shape[0]
  loss = 0.0
  for k in xrange(num_train):
    origin_scors = X[k].dot(W)
    probabilities = np.zeros(origin_scors.shape)
    logc = -np.max(origin_scors)
    total_sum = np.sum(np.exp(origin_scors - logc))

    for i in xrange(num_classes):
        probabilities[i] = np.exp(origin_scors[i] - logc) / total_sum

    for i in xrange(num_classes):
        if i == y[k]:
            dW[:, i] += - X[k] * (1 - probabilities[i])  # dW[:, i]:3073X1  X[k]: 3073 X 1
        else:
            dW[:, i] += X[k] * probabilities[i]

    loss += -np.log(probabilities[y[k]])

  # Right now the loss is a sum over all training examples, but we want it
  # to be an average instead so we divide by num_train.
  loss /= num_train
  dW /= num_train
  dW += reg*W # regularize the weights
  # Add regularization to the loss.
  loss += 0.5 * reg * np.sum(W * W)

  return loss, dW

猜你喜欢

转载自blog.csdn.net/rootmego/article/details/84327899