在实际机器学习工作中,最常用的就是实值函数 y 对向量 x或矩阵 X 求导,比如最简单的线性回归问题中由目标函数
dJ(w)
求解最佳参数向量
w
。
本文以线性回归问题中由目标函数
dJ(w)
求解最佳参数向量
w
问题为例子,介绍个人总结的一点机器学习矩阵求导的的技巧和方法,其中包括:
1. 全微分与偏导数关系
2. 迹技巧
3. 常用的矩阵求导公式
一. 利用矩阵偏导数与微分的关系
1.1 实值函数对向量的微分
1.2 实值函数对矩阵的微分
1.3 上面两个公式的应用
1.4 运算法则
- 加减法:
d(X±Y)=dX±dY
- 矩阵乘法:
d(XY)=dXY+XdY
- 转置:
d(XT)=(dX)T
- 迹:
dtr(X)=tr(dX)
- 逆:
dX−1=−X−1dXX−1
此式可在
XX−1=I
两侧求微分来证明。
1.5 迹技巧
- 标量套上迹:
a=tr(a)
- 转置:
tr(AT)=tr(A)
- 线性:
tr(A±B)=tr(A)±tr(B)
。
- 矩阵乘法交换:
tr(AB)=tr(BA)
二 用迹的性质简化矩阵求导问题。
性质1
tra=a,tr(aA)=a∗trA
,a为标量
性质2
tr(A+B)=trA+trB
性质3
trAB=trBA,trABC=trCAB=trBCA
性质4
trA=trAT
性质5
▽Atr(AB)=BT
性质6
▽Atr(ABATC)=CAB+CTABT
实例计算:使用迹的技巧求解线性回归的最佳参数。
▽wJ(w)=▽wtrJ(w)
=▽wtr(Xw−Y)T(Xw−Y)
=▽wtr(wTXTXw−YTXw−wTXTY+YTY)
- 注:
- 这里应该明确的是J(w) 是两个向量的内积,因此为标量,可以应用性质1: tr a = a
-
▽wJ(w)
是标量J(w)对一个向量 w 求导,其结果是一个向量,维数和w向量相同。
▽wJ(w)=▽wtrJ(w)
=▽wtr(Xw−Y)T(Xw−Y)
=▽wtr(wTXTXw−YTXw−wTXTY+YTY)
=▽wtr(wTXTXw)−▽wtr(YTXw)−▽wtr(wTXTY)
▽wJ(w)=▽wtrJ(w)
=▽wtr(Xw−Y)T(Xw−Y)
=▽wtr(wTXTXw−YTXw−wTXTY+YTY)
=▽wtr(wTXTXw)−▽wtr(YTXw)−▽wtr(wTXTY)
=▽wtr(wwTXTX)−▽wtr(YTXw)−▽wtr(wTXTY)
=▽wtr(wwTXTX)−2∗▽wtr(YTXw)
- 注:
- 这里应用
trAB=trBA(A=wTXTX,B=w)
- 以及
trAT=trA(A=wTXTY)
▽wJ(w)=▽wtrJ(w)
=▽wtr(Xw−Y)T(Xw−Y)
=▽wtr(wTXTXw−YTXw−wTXTY+YTY)
=▽wtr(wTXTXw)−▽wtr(YTXw)−▽wtr(wTXTY)
=▽wtr(wwTXTX)−▽wtr(YTXw)−▽wtr(wTXTY)
=▽wtr(wwTXTX)−2∗▽wtr(YTXw)
=▽wtr(wIwTXTX)−2∗▽wtr(YTXw)
=(XTXwI+XTXIw)−2∗▽wtr(YTXw)
- 注:
- 这里应用
▽AtrABATC=CAB+CTABT(A=w,C=XTX,B=I),I是1维单位矩阵
- 以及
trAT=trA(A=wTXTY)
▽wJ(w)=▽wtrJ(w)
=▽wtr(Xw−Y)T(Xw−Y)
=▽wtr(wTXTXw−YTXw−wTXTY+YTY)
=▽wtr(wTXTXw)−▽wtr(YTXw)−▽wtr(wTXTY)
=▽wtr(wwTXTX)−▽wtr(YTXw)−▽wtr(wTXTY)
=▽wtr(wwTXTX)−2∗▽wtr(YTXw)
=▽wtr(wIwTXTX)−2∗▽wtr(YTXw)
=(XTXwI+XTXIw)−2∗▽wtr(YTXw)
=2∗XTXw−2∗XTYw
- 注:
- 这里应用
▽Atra(AB)=BT
三. 机器学习中常用的矩阵求导
矩阵/向量求导问题中要明确是什么量对什么量求导,得到的是什么形式的量
- 实值函数对向量求导,结果是同样维度和方向的向量
- 实值函数对矩阵求导,结果是同样维度的矩阵
重要的矩阵求导公式:公式证明可以用微分分解加迹技巧证明。
证明第一条公式:
d(xTAx)=d(xT)Ax+xTd(Ax)
=(Ax)Tdx+xT(AT)Tdx
=(xTAT+xTA)dx
则:
∂xTAx∂x=(xTAT+xTA)T=(AT+A)x
- 例子:线性回归问题中由目标函数
dJ(w)
求解最佳参数向量
w
问题
▽wJ(w)
=▽w(Xw−Y)T(Xw−Y)
=▽w(wTXTXw−YTXw−wTXTY+YTY)
=▽w(wTXTXw)−▽w(YTXw)−▽w(wTXTY)
=2∗XTXw−XTY−XTY
=2∗XTXw−2∗XTY
- 注:求导公式忘了可以用微分转换和迹技巧推导。