近期把DNN的反向传播又好好的研究了一下。之前一直有疑虑是因为很多文档里边出现
∂z(l+1)∂z(l)
这种表达式,然后
z(l+1)
和
z(l)
还是矩阵,这下就变得非常烦人了,因为没有哪本数学书定义了矩阵对矩阵的导数。只有标量函数对矩阵,矩阵对标量,标量对向量,向量对标量以及向量对向量。所以我觉得有必要在好好把这块弄一下,写清楚。
首先是DNN的模型:
⎧⎩⎨⎪⎪⎪⎪z(l+1)=θ(l+1)⋅a(l)+b(l+1)⋅1T,a(l)=g(z(l+1)),J=J(a(N))l=1,2,3,…,N(1028)
这里边,
a(1)=X
也就是输入,
1
是列向量。然后:
X=⎛⎝⎜|X1|………|Xm|⎞⎠⎟(1029)
也就是说,一共有m个样本。
通常的文章怎么描述的呢?定义
δ(l)=∂J∂z(l)
,假如计算出了
δ(l)
那么
∂J∂θ(l)=∂J∂z(l)⋅∂z(l)∂θ(l)
,然后
∂J∂z(l−1)=∂J∂z(l)⋅∂z(l)∂z(l−1)
,由于
∂J∂z(N)
很容易计算,所以后边递推就可以了。但是问题在于
∂z(l)∂z(l−1)
到底是啥?雅可比矩阵吗?
z(l)
和
z(l−1)
都是矩阵,没有一本数学书有这么直接写的。矩阵对矩阵的导数目前还处于undefined的状态。所以这个符号其实是没有严格定义的。只不过按照其他的方式推导出来后,结果看上去很像,所以就这么写了,但是如果真的较真说这个矩阵对矩阵的定义是什么怎么算,那就没法严格的说了。所以这篇文章就是仔细的把这块严格的做一下。
然后有几个公式定理需要推导一下,推到完了,很多东西就迎刃而解了。
f:Rm×n↦R
也就是一个矩阵的标量函数,那么若
g:Rp×q↦Rm×n
,那么复合函数:
f∘g:Rp×q↦R
,例如
f(z), z=θX
,又如
f(a), a=g(z)
。在这种情况下,我们希望得到
∂f∂θ
或者
∂f∂z
,该如何求解?其实这种情况,需要用到matrix vectorization和kronecker product,但是我们所遇到的恰好是线性变换和element-wise function,所以对于这两种情况,完全可以简化。
Lemma 1
若
g
是一个矩阵左乘或者右乘,也就是
g=θX
这种情况,那么有:
[∂f∂X]i,j=∑m∑n∂f∂gm,n⋅∂gm,n∂Xi,j=∑m∑n∂f∂gm,n⋅∂∑kθm,kXk,n∂Xi,j=∑m∂f∂gm,j⋅θm,i=[θT⋅∂f∂g]i,j(1030)
因此:
∂f∂X=θT⋅∂f∂g
其中第一个等号是全微分公式,第二个等号是矩阵乘法展开,第三个等号是因为
k≠i, n≠j
时
∂θm,kXk,n∂Xi,j=0
,最后一个等号就是矩阵乘法了。
同理:
∂f∂θ=∂f∂g⋅XT
Lemma 2
假如
g
是一个非线性函数,但是是一个element-wise的函数,那么:
[∂f∂X]i,j=∑m∑n∂f∂gm,n⋅∂gm,n∂Xi,j=[∂f∂a]i,j⋅[g′(z)]i,j(1031)
因此:
∂f∂X=∂f∂a⊙g′(z)(1032)
这里边
⊙
是hardamard product,其实就是元素乘法。
有了Lemma 1和Lemma 2之后很多东西就迎刃而解了。定义
δ(l)=∂J∂z(l)
,而
z(l)=θ(l)⋅a(l−1)+b(l)⋅1T
那么显然:
∂J∂θ(l)∂J∂b(l)=δ(l)⋅(a(l−1))T=δ(l)⋅1(1033)
那么对于有了
δ(l+1)
计算
δ(l)
呢?首先由于
z(l+1)=θ(l+1)⋅a(l)+b(l+1)⋅1T
,所以:
∂J∂a(l)=(θ(l+1))T⋅δ(l+1)
这里用了Lemma 1的第一个,然后根据Lemma 2,
a(l)=g(z(l))
,因此:
∂J∂z(l)=(θ(l+1))T⋅δ(l+1)⊙g′(z(l))(7)
这样就完成了推导。
我认为这种方式比
∂z(l+1)∂z(l)
这种写法要清晰明白很多,因为矩阵对矩阵的导数一定是得每个元素都要求导。这样就出来一个mn x mn矩阵了,但是目前这种方式,就明白清晰了很多。
另外如果吧bias一项放入
θ
里边去,然后
a(l)
不上一行1,也是可以的,就直接用:
∂J∂θ(l)=δ(l)⋅(a(l−1))T
即可。