[cs231n Assignment Notes] Part 2: SVM Classifier

See also: cs231n assignment1 SVM complete code

231n assignment: the multiclass SVM loss function and how to compute its gradient (the best explanation)

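I won't go through the derivation of the multiclass SVM loss here. The end result is a function that, for a minibatch of N samples, computes both the loss and the gradient matrix and returns them. The gradient part is as follows: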
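The derivation is skipped, but as a compact reference, these are the quantities the function below computes. The notation is chosen here, not taken from the original post: $x_i$ is the row `X[i]`, $s_{ij}$ is `scores[j]` for sample $i$, $L_i$ is the hinge-loss sum for sample $i$, the margin $\Delta$ is fixed at 1, and $\lambda$ is `reg`:

```latex
\begin{aligned}
s_{ij} &= x_i^\top W_{:,j}, \qquad
L = \frac{1}{N}\sum_{i=1}^{N}\sum_{j\neq y_i}\max\bigl(0,\ s_{ij}-s_{iy_i}+1\bigr) + \lambda\sum_{k,l}W_{kl}^2 \\
\frac{\partial L_i}{\partial W_{:,j}} &= \mathbb{1}\bigl(s_{ij}-s_{iy_i}+1>0\bigr)\,x_i \qquad (j\neq y_i) \\
\frac{\partial L_i}{\partial W_{:,y_i}} &= -\Bigl(\sum_{j\neq y_i}\mathbb{1}\bigl(s_{ij}-s_{iy_i}+1>0\bigr)\Bigr)\,x_i \\
\frac{\partial L}{\partial W} &= \frac{1}{N}\sum_{i=1}^{N}\frac{\partial L_i}{\partial W} + 2\lambda W
\end{aligned}
```

The final term, $2\lambda W$, is the gradient of the regularizer $\lambda\sum W^2$, which is `reg * np.sum(W * W)` in code.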

import numpy as np

def svm_loss_naive(W, X, y, reg):
  """
  Structured SVM loss function, naive implementation (with loops).

  Inputs have dimension D, there are C classes, and we operate on minibatches
  of N examples.

  Inputs:
  - W: A numpy array of shape (D, C) containing weights.
  - X: A numpy array of shape (N, D) containing a minibatch of data.
  - y: A numpy array of shape (N,) containing training labels; y[i] = c means
    that X[i] has label c, where 0 <= c < C.
  - reg: (float) regularization strength

  Returns a tuple of:
  - loss as single float
  - gradient with respect to weights W; an array of same shape as W
  """
  dW = np.zeros(W.shape) # initialize the gradient as zero

  # compute the loss and the gradient
  num_classes = W.shape[1]
  num_train = X.shape[0]
  loss = 0.0
  for i in range(num_train):
    scores = X[i].dot(W)
    correct_class_score = scores[y[i]]
    for j in range(num_classes):
      if j == y[i]:
        continue
      margin = scores[j] - correct_class_score + 1 # note delta = 1
      if margin > 0:
        loss += margin
        dW[:, j] += X[i]     # class j violates the margin: its column gets +X[i]
        dW[:, y[i]] -= X[i]  # correct class: -X[i] once for every violating class j


  # Right now the loss is a sum over all training examples, but we want it
  # to be an average instead so we divide by num_train.
  loss /= num_train
  dW /= num_train
  # Add the regularization term; reg is the regularization strength.
  loss += reg * np.sum(W * W)
  dW += 2 * reg * W  # gradient of reg * sum(W * W) with respect to W

  #############################################################################
  # TODO:                                                                     #
  # Compute the gradient of the loss function and store it dW.                #
  # Rather than first computing the loss and then computing the derivative,   #
  # it may be simpler to compute the derivative at the same time that the     #
  # loss is being computed. As a result you may need to modify some of the    #
  # code above to compute the gradient.                                       #
  #############################################################################


  return loss, dW
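As a cross-check on the loop above, here is a minimal sketch of a vectorized version plus a finite-difference gradient check. Assumptions: `svm_loss_vectorized` and `numerical_gradient` are names chosen here for illustration (not from the original post), the margin is fixed at 1, and the regularizer is `reg * np.sum(W * W)`, so its gradient is `2 * reg * W`. The mask mirrors the two `dW` updates in the loop: every violating class column gets `+X[i]`, and the correct-class column gets `-X[i]` once per violation.

```python
import numpy as np


def svm_loss_vectorized(W, X, y, reg):
    """Multiclass SVM loss and gradient without explicit loops (margin = 1).

    W: (D, C) weights, X: (N, D) data, y: (N,) labels, reg: regularization strength.
    """
    num_train = X.shape[0]
    rows = np.arange(num_train)
    scores = X.dot(W)                                # (N, C) class scores
    correct = scores[rows, y][:, None]               # (N, 1) correct-class scores
    margins = np.maximum(0, scores - correct + 1)    # hinge for every class
    margins[rows, y] = 0                             # the correct class adds no loss
    loss = margins.sum() / num_train + reg * np.sum(W * W)

    # 1 where class j violates the margin for sample i, else 0.
    mask = (margins > 0).astype(X.dtype)
    # Correct-class column: minus the number of violating classes for that row.
    mask[rows, y] = -mask.sum(axis=1)
    dW = X.T.dot(mask) / num_train + 2 * reg * W
    return loss, dW


def numerical_gradient(f, W, h=1e-5):
    """Central-difference gradient of scalar function f at W (W is restored)."""
    grad = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        old = W[idx]
        W[idx] = old + h
        f_plus = f(W)
        W[idx] = old - h
        f_minus = f(W)
        W[idx] = old                                 # restore the original entry
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    N, D, C = 8, 5, 3
    X = rng.standard_normal((N, D))
    y = rng.integers(0, C, size=N)
    W = 0.001 * rng.standard_normal((D, C))          # small weights: margins stay near 1
    reg = 0.1

    loss, dW = svm_loss_vectorized(W, X, y, reg)
    num = numerical_gradient(lambda w: svm_loss_vectorized(w, X, y, reg)[0], W)
    print("loss:", loss)
    print("max abs gradient error:", np.max(np.abs(dW - num)))
```

The printed gradient error should be tiny. Small random weights keep every margin close to 1, so no hinge kink falls inside the finite-difference step, which is what makes the numerical check trustworthy for a non-smooth loss.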


Reposted from www.cnblogs.com/joelwang/p/10824441.html