cs231 softmax function derivative

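Copyright notice: Wang Jialin's 2018 book 《SPARK大数据商业实战三部曲》, published by Tsinghua University Press; official Tsinghua University Press flagship store (Tmall): https://qhdx.tmall.com/?spm=a220o.1000855.1997427721.d4918089.4b2a2e5dT6bUsM https://blog.csdn.net/duan_zhihua/article/details/82925795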

cs231 softmax function derivative:
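
The gradient accumulated by the function below can be derived in a few lines. This is a sketch under assumed notation that does not appear in the original post: s is the score vector of example i, p_j is the softmax probability of class j, L_i is the per-example cross-entropy loss, x_i is treated as a vector of length D, and λ stands for the `reg` argument.

```latex
\begin{aligned}
s &= x_i W, \qquad p_j = \frac{e^{s_j}}{\sum_{k} e^{s_k}}, \qquad L_i = -\log p_{y_i} \\
\frac{\partial L_i}{\partial s_j} &= p_j - \mathbb{1}[j = y_i] \\
\frac{\partial L_i}{\partial W_{:,j}} &= \bigl(p_j - \mathbb{1}[j = y_i]\bigr)\, x_i \\
L &= \frac{1}{N}\sum_{i=1}^{N} L_i + \lambda \sum_{d,c} W_{d,c}^2, \qquad
\nabla_W L = \frac{1}{N}\sum_{i=1}^{N} \nabla_W L_i + 2\lambda W
\end{aligned}
```

For j ≠ y_i the column update is +p_j x_i, and for j = y_i it is -(1 - p_j) x_i, which are the two branches of the inner loop. Subtracting the maximum score before exponentiating leaves every p_j unchanged, because softmax is invariant to adding a constant to all scores.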

import numpy as np


def softmax_loss_naive(W, X, y, reg):
    """
    Softmax loss function, naive implementation (with loops)
    Inputs have dimension D, there are C classes, and we operate on minibatches
    of N examples.
    Inputs:
    - W: A numpy array of shape (D, C) containing weights.
    - X: A numpy array of shape (N, D) containing a minibatch of data.
    - y: A numpy array of shape (N,) containing training labels; y[i] = c means
      that X[i] has label c, where 0 <= c < C.
    - reg: (float) regularization strength
    Returns a tuple of:
    - loss as single float
    - gradient with respect to weights W; an array of same shape as W
    """
    # Initialize the loss and gradient to zero.
    loss = 0.0
    dW = np.zeros_like(W)

    #############################################################################
    # TODO: Compute the softmax loss and its gradient using explicit loops.     #
    # Store the loss in loss and the gradient in dW. If you are not careful     #
    # here, it is easy to run into numeric instability. Don't forget the        #
    # regularization!                                                           #
    #############################################################################
    for i in range(X.shape[0]):
        score = np.dot(X[i], W)
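        score -= np.max(score)  # shift by the max score for numerical stability
        score = np.exp(score)  # exponentiate
        softmax_sum = np.sum(score)  # denominator of the softmax
        score /= softmax_sum  # divide by the denominator to get the softmax probabilities
        # accumulate the gradient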
        for j in range(W.shape[1]):
            if j != y[i]:
                dW[:, j] += score[j] * X[i]
            else:
                dW[:, j] -= (1 - score[j]) * X[i]

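        loss -= np.log(score[y[i]])  # cross-entropy of the correct class
    loss /= X.shape[0]  # average over the minibatch
    dW /= X.shape[0]  # average over the minibatch
    loss += reg * np.sum(W * W)  # add the L2 regularization term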
    dW += 2 * reg * W
    #############################################################################
    #                          END OF YOUR CODE                                 #
    #############################################################################

    return loss, dW
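
One way to gain confidence in a gradient like the one above is to compare it with a centered finite-difference estimate. The sketch below is not from the original post: `softmax_loss_vectorized` is a loop-free rewrite of the same loss written here for illustration, `numerical_gradient` is a hypothetical helper, and the random shapes, seed, and step size `h` are arbitrary assumptions.

```python
import numpy as np


def softmax_loss_vectorized(W, X, y, reg):
    """Loop-free sketch of the same softmax loss and gradient."""
    num_train = X.shape[0]
    scores = X.dot(W)                                # shape (N, C)
    scores -= np.max(scores, axis=1, keepdims=True)  # shift for numerical stability
    probs = np.exp(scores)
    probs /= np.sum(probs, axis=1, keepdims=True)    # row-wise softmax

    # Average cross-entropy plus L2 regularization.
    loss = -np.sum(np.log(probs[np.arange(num_train), y])) / num_train
    loss += reg * np.sum(W * W)

    # Gradient w.r.t. the scores is probs minus the one-hot labels.
    dscores = probs.copy()
    dscores[np.arange(num_train), y] -= 1.0
    dW = X.T.dot(dscores) / num_train + 2 * reg * W
    return loss, dW


def numerical_gradient(f, W, h=1e-5):
    """Centered finite-difference gradient of the scalar function f at W."""
    grad = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        old = W[idx]
        W[idx] = old + h
        f_plus = f(W)
        W[idx] = old - h
        f_minus = f(W)
        W[idx] = old  # restore the original entry
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad


# Small random problem: D = 5 features, C = 3 classes, N = 4 examples.
rng = np.random.default_rng(0)
W = rng.standard_normal((5, 3)) * 0.01
X = rng.standard_normal((4, 5))
y = rng.integers(0, 3, size=4)
reg = 0.1

loss, dW = softmax_loss_vectorized(W, X, y, reg)
dW_num = numerical_gradient(
    lambda w: softmax_loss_vectorized(w, X, y, reg)[0], W
)
max_abs_err = np.max(np.abs(dW - dW_num))
print("max abs difference:", max_abs_err)
```

If the analytic gradient is right, the printed difference should be tiny compared with the gradient entries themselves. The same check can be pointed at `softmax_loss_naive` by swapping the function inside the lambda.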

https://github.com/duanzhihua 


Reposted from blog.csdn.net/duan_zhihua/article/details/82925795