反向传播算法的公式推导

概念

反向传播(Back Propagation, BP)算法是使用梯度下降法相关的算法来优化一个神经网络时计算每一层梯度的方法,主要使用了多元函数的链式法则

已知多元函数 u = g ( y 1 , y 2 , . . . , y m ) ,且 y i = f i ( x ) ,所有函数都可微,则

u x = i = 1 m u y i y i x

公式推导

1、模型

不失一般性,我们考虑以下4层结构的神经网络(全连接):
这里写图片描述

2、符号说明

符号 含义
n l 网络层数
y j 输出层第 j 类标签
S l l 层神经元个数(不包括偏置)
g ( x ) 激活函数
w i j ( l ) l 层第 j 个单元与第 l + 1 层第 i 个单元之间的链接参数
b i ( l ) l 层的偏置与第 l + 1 层第 i 个单元之间的链接参数
z i ( l ) l 层第 i 个单元的输入(加权和,包括偏置)
a i ( l ) l 层第 i 个单元的输出(激活函数的值)
δ i ( l ) l 层第 i 个单元的输入的偏导(或称为灵敏度、残差)
J ( θ ) 代价函数

3、符号定义

z i ( l ) = b i ( l 1 ) + j = 1 S l 1 w i j ( l 1 ) a j ( l 1 ) a i ( l ) = g ( z i ( l ) ) J ( θ ) = 1 2 j = 1 S l ( y j a j ( l ) ) 2 δ i ( l ) = J ( θ ) z i ( l )

4、推导过程

δ i ( n l ) = J ( θ ) z i ( n l ) = 1 2 z i ( n l ) j = 1 S n l ( y j a j ( n l ) ) 2 = 1 2 z i ( n l ) j = 1 S n l ( y j g ( z j ( n l ) ) ) 2 = 1 2 z i ( n l ) ( y j g ( z i ( n l ) ) ) 2 = ( y i a i ( n l ) ) g ( z i ( n l ) ) δ i ( l ) = J ( θ ) z i ( l ) = j = 1 S l + 1 J ( θ ) z j ( l + 1 ) z j ( l + 1 ) z i ( l ) = j = 1 S l + 1 δ j ( l + 1 ) z j ( l + 1 ) z i ( l ) = j = 1 S l + 1 δ j ( l + 1 ) z i ( l ) ( b j ( l ) + k = 1 S l w j k ( l ) a k ( l ) ) = j = 1 S l + 1 δ j ( l + 1 ) z i ( l ) ( b j ( l ) + k = 1 S l w j k ( l ) g ( z k ( l ) ) ) = j = 1 S l + 1 δ j ( l + 1 ) z i ( l ) ( w j i ( l ) g ( z i ( l ) ) ) = j = 1 S l + 1 δ j ( l + 1 ) w j i ( l ) g ( z i ( l ) ) = g ( z i ( l ) ) j = 1 S l + 1 δ j ( l + 1 ) w j i ( l ) J ( θ ) w i j ( l ) = J ( θ ) z i ( l + 1 ) z i ( l + 1 ) w i j ( l ) = δ i ( l + 1 ) z i ( l + 1 ) w i j ( l ) = δ i ( l + 1 ) w i j ( l ) ( b i ( l ) + k = 1 S l w i k ( l ) a k ( l ) ) = δ i ( l + 1 ) a j ( l ) J ( θ ) b i ( l ) = δ i ( l + 1 ) b i ( l ) ( b i ( l ) + k = 1 S l w i k ( l ) a k ( l ) ) = δ i ( l + 1 )

向量形式的公式

δ ( l ) = ( W ( l ) ) T δ ( l + 1 ) g ( z ( l ) ) J ( θ ) W ( l ) = δ ( l + 1 ) ( a ( l ) ) T J ( θ ) b ( l ) = δ ( l + 1 )

其中, 表示每个元素相乘,粗体的小写符号表示列向量,粗体的大写符号表示矩阵。

参考

([1] 中的公式推导有错误,本文已纠正)
[1] https://www.cnblogs.com/nowgood/p/backprop.html
[2] Bouvrie J. Notes on convolutional neural networks[J]. 2006.

猜你喜欢

转载自blog.csdn.net/HappyRocking/article/details/80435544