本文由**罗周杨[email protected]**原创,转载请注明出处。 本文已经发表在原作者博客 blog.stupidme.me/2018/08/25/…
反向传播是深度学习的基石。
导数
先回顾下导数:
函数在每个变量的导数就是偏导数。
对于函数,,同时
梯度就是偏导数组成的矢量。上述例子中,
链式法则
对于简单函数,我们可以根据公式直接计算出其导数。但是对于复杂的函数,我们就没那么容易直接写出导数。但是我们有链式法则(chain rule)。
定义不多说,咱们举个例子,感受一下链式法则的魅力。
我们熟悉的sigmoid函数 ,如果你记不住它的导数,我们怎么求解呢?
求解步骤如下:
- 将函数模块化,分成多个基本的部分,对于每一个部分都可以使用简单的求导法则进行求导
- 使用链式法则,将这些导数链接起来,计算出最终的导数
具体如下:
令,则
令,则
令,则
令,则
令,则
上面的e实际上就是我们的$$\sigma(x)$$,那么根据链式法则,有:
sigmoid函数的导数可以直接用自身表示,这也是很奇妙的性质了。这样的求导过程是不是很简单?
反向传播代码实现
求导和链式法则我都会了,那么具体的前向传播和反向传播的代码是怎么样的呢?
这次我们使用一个更复杂一点点的例子:
我们先看下它地forward pass代码:
import math
x = 3
y = -4
sigy = 1.0 / (1 + math.exp(-y)) # sigmoid function
num = x + sigy # 分子
sigx = 1.0 / (1 + math.exp(-x))
xpy = x + y
xpy_sqr = xpy**2
den = sigx + xpy_sqr # 分母
invden = 1.0 / den
f = num * invden # 函数
复制代码
上述过程很简单对不对,就是把复杂的函数拆解成一个一个简单函数。
我们看看接下来的反向传播过程:
dnum = invden
复制代码
因为
所以有
也就是
dinvden = num # 同理
dden = (-1.0 / (den**2)) * dinvden # 链式法则
复制代码
展开来说:
又
所以
所以,同理,我们可以写出所有的导数:
dsigx = (1) * dden
dxpy_sqr = (1) * dden
dxpy = (2 * xpy) * dxpy_sqr
# backprob xpy = x + y
dx = (1) * dxpy
dy = (1) * dxpy
# 这里开始,请注意使用的是"+=",而不是"=”
dx += ((1 - sigx) * sigx) * dsigx # dsigma(x) = (1 - sigma(x))*sigma(x)
dx += (1) * dnum
# backprob num = x + sigy
dsigy = (1) * dnum
# 注意“+=”
dy += ((1 - sigy) * sigy) * dsigy
复制代码
问题:
- 上面计算过程中,为什么要用“+=”替代“=”呢?
如果变量x,y在前向传播的表达式中出现多次,那么进行反向传播的时候就要非常小心,使用+=而不是=来累计这些变量的梯度(不然就会造成覆写)。这是遵循了在微积分中的多元链式法则,该法则指出如果变量在线路中分支走向不同的部分,那么梯度在回传的时候,就应该进行累加。
联系我
-
Email: [email protected]
-
WeChat: luozhouyang0528
-
个人公众号,你可能会感兴趣: