梯度更新方法总结

梯度更新

举例说明

对于逻辑归回梯度求解:

  • 假设预测输出函数:

    h ( x 0 , x 1 , . . . , x n ) = i = 0 n θ i x i + θ 0

    i:一次输入中,第i个数据

  • 实际输出: y

  • Cost Function函数:

    J ( θ 0 , θ 1 , . . . , θ n ) = 1 m j = 0 m ( y ( j ) h ( j ) ) 2

    j:第j个输入,总共输入m个数据

  • 梯度求解的最终目的:令 J 的值 最小,根据微积分,只要求出 J = d J d θ = J θ i 0

    根据导数的定义: J = J J J 肯定是往函数最小值方向移动,即 J 0 J min

  • 但是通常情况下,由于 J 很难直接求解出来,换个思路就是通过更新参数 θ i ( j ) 来实现: J 0 J min

    • 如何保证参数 θ i 更新方向是 J min

      求解:

      θ i + = θ i Δ θ i
      根据导数定义,如果保证 Δ θ i 和偏导数 J θ i 数值正负号一致,上述参数 θ i 更新后, J min

    • 如何确定每次参数要更新多少?

      对于参数更新公式:

      θ i + = θ i Δ θ i
      只是保证了参数更新方向正确,为了防止一步跨度太大,最终 J 的值一下子垮过min,从而造成抖动,需要引入学习效率 η ,最终公式:
      θ i + = θ i η Δ θ i

      一般: η [ 0 , 1 ]

  • 小结:

    θ i + = θ i η Δ θ i
    该公式属于人造并非数学推导,主要是符合了参数更新方向、同时人为设定更新步长

PS:某个权重更新的值=0,即 J θ i = 0 ,并不代表 d J d θ = 0

常见梯度更新方法

参考:http://blog.csdn.net/boon_228/article/details/51721835

BGD

批次梯度下降

  • 概念:每次更新所有样本数据来更新一次 J ( θ ) 的参数 θ
  • 预测函数:
    h ( x 0 , x 1 , . . . , x n ) = i = 0 n θ i x i + θ 0
  • 对于cost function:
    J ( θ 0 , θ 1 , . . . , θ n ) = 1 m j = 0 m ( y ( j ) h ( j ) ) 2
  • 有参数更新公式:
    θ i + = θ i η J θ i = θ i η 2 m j = 0 m ( y ( j ) h ( j ) ) x i
  • 这种参数更新方法是批量梯度更新,也就是每次更新 θ i 都需要用到这样本里所有数据
  • 小结:
    图1 -w300
    • 优点:没更新一次,都用所有样本数据进行更新,这样就求解了全局最优解,同时通过计算公式可以发现可以并行实现;
    • 缺点:由于每次更新都要计算该批次训练样本数据,如果批次样本数量太大,训练过程会很慢

SGD

随机梯度下降

  • 概念:由于BGD每次更新都需要用到批次里所有样本数据,所以引入随机梯度下降概念;和BGD的区别在于,每次训练样本只输入一个,通过不断输入不同样本来修正参数,而BGD是一次输入所有样本来修正参数

    区别 BGD SGD
    h ( x ) i = 0 n θ i x i + θ 0 i = 0 n θ i x i + θ 0
    J ( θ ) 1 m j = 0 m ( y ( j ) h ( j ) ) 2 1 2 ( y h ) 2
    每次输入样本数 所有 1个
    循环次数 一次 =样本数量
  • 小结:
    图1 -w300

    • 优点:训练速度快;
    • 缺点:准确度下降(噪音比BGD大),并非全局最优解;不易于并行实现;

MBGD

小批量梯度下降

  • 概念:结合了BGD和SGD的优点:将所有的样本分割成很多小份,每次用这个小样本进行BGD训练,即:

    for SGD:
        for BGD:
            ...
  • 小结:
    • 目前算法比较常用的梯度下降算法用MBGD,常用的小份样本数量有:64、10

常用梯度更新公式推导

神经网络常用梯度更新公式推导

Pooling

  1. 在池化层,设置了固定的w,所以参数不在此更新
  2. 主要类似设置了反向传播的阀门,保证反向阀门开合以及打开大小
  3. 这里的pooling方法是一般池化,即池化过程中,各个模块的边界不重叠。除了不重叠方法外,还有重叠池化、空金字塔池化

Max Pooling

  • 概念:反向传播求导数过程类似分段函数求导

    分段函数在分段点的导数必须分别求左右导数,而在非分段点的导数如常

    图3 -w300

  • 有矩阵如下:

    1 2 3 4 5 6 7 8 9

  • 求解整个矩阵的max_pooling,则:

    y = max ( x 1 , x 2 , . . . , x 9 ) = 9

  • 反向传播时:

    y x n = { 0 , x n 1 , x n

    这个公式可以通过如下代码的条件语句实现:

    if x_n == y:
        y_d = 1
    else:
        y_d = 0
  • 根据反向传播公式的链式原理:

    w + = w η Δ w = w η J w | x = x 0 = w η J O u t 1 O u t 1 N e t 1 . . . N e t i w | x = x 0

    如果pooling层对 x i 偏导数为0,通过 x i 向前的参数修正值都恒为0,即不传播;

  • 小结:
    对于max pooling参数只通过max值那个点反向传播,在上述例子中,即通过 x 9 向前修正参数,其他值均认为是无用数据丢弃;

Mean Pooling

  • 概念:求解n x m矩阵的上所有点的均值

    图4 -w300

  • 有矩阵如下:

    1 2 3 4 5 6 7 8 9

  • 求解整个矩阵的mean pooling,则:
    y = ( x 1 + x 2 + + x n ) n = 5
  • 反向传播时:
    y x n = 1 n = 1 9
  • 小结:
    对于mean pooling ,反向传播过程中,只是给传播链条添加一个常数 1 n ,即作为固定权重使用

激活函数

同Pooling,没有需要更新的权重,起到了传播过程中的阀门作用

ReLU

  • 公式:
    y = { 0 , x 0 x , x > 0
  • 反向传播时:
    d y d x = { 0 , x 0 1 , x > 0
  • 小结:
    对于输出值: x 0 的神经元,流经ReLU层后的反向传播,该神经元以及前面对应的参数都会被丢弃,即不再更新参数

Sigmoid

  • 公式:
    y = 1 1 + e x
  • 反向传播时:
    d y d x = y ( 1 y )

tanh

  • 公式:
    y = e x e x e x + e x
  • 反向传播时:
    d y d x = 2 ( 1 y ) ( 1 + y )

soft Max

  • 公式:

    y i = e x i i = 0 n e x i

    x i :第i个输入值
    y i x i 对应的输出值

  • 反向传播时:
    y i x i = y i ( 1 y i )

Loss函数

交叉熵

理论上均方差 C = ( y a ) 2 n 值小的程度,作为判断神经元预测结果和实际结果的偏离程度很好理解
但是,实际问题中,针对分类问题,交叉熵的表现比均方差来的好

  • 设经过soft Max的神经元输出:
    z 1 ( i = 1 ) 2 ( i = 2 ) 3 ( i = 3 ) 1 ( k = 1 ) 0.9 0.1 0.1 2 ( k = 2 ) 0.1 0.9 0.1 3 ( k = 2 ) 0.1 0.1 0.9
  • 实际结果:
    y 1 ( i = 1 ) 2 ( i = 2 ) 3 ( i = 3 ) 1 ( k = 1 ) 1 0 0 2 ( k = 2 ) 0 1 0 3 ( k = 3 ) 0 0 1
  • 交叉熵公式:

    C k = 1 n i = 1 n [ y i ( k ) ln z i ( k ) + ( 1 y i ( k ) ) ln ( 1 z i ( k ) ) ] ,   y i ( k ) { 0 , 1 } C 1 0.035 C 2 0.035 C 3 0.035 C = 1 m k = 1 m ( C k ) 0.035

    C k :第k个样本的交叉熵
    z i :某个样本的第i个类别预测结果
    y i :某个样本的第i个类别实际结果

  • 反向传播时, y i 是已知值:
    C z i = C C k C k z i = 1 m ( y i ( k ) n z i ( k ) ) ,   y i ( k ) { 0 , 1 }

梯度更新代码实现

TBD

猜你喜欢

转载自blog.csdn.net/QQ2627866800/article/details/79351546
今日推荐