為何邏輯回歸的損失函數是用交叉熵而非均方誤差?

前言

本文是筆者在學習吳恩達的深度學習課程時所碰到的問題。
課程中雖然有提及將均方誤差用於邏輯回歸可能會造成多個局部最小值,但是並未給出具體例子。
而本篇文章將嘗試給出幾個例子,並說明其背後的原因。

在進入正文以前,必須先澄清一點:
loss function(損失函數)指的是單一個樣本的誤差。
而cost function(代價函數,成本函數)指的是數據集中所有樣本誤差的均值。

邏輯回歸的損失函數推導

邏輯回歸的輸出值 y ^ \hat{y} 表示的是當前輸入 x x 是屬於 y = 1 y=1 這個類別的機率,用數學的語言來說明的話,就是:
P ( y = 1 x ) = y ^ P(y=1|x) = \hat{y}
輸入x是屬於y=0這個類別的機率則是:
P ( y = 0 x ) = 1 y ^ P(y=0|x) = 1-\hat{y}
我們可以將上面這兩個式子可以合併,代表正確地將 x x 分到它所屬的類別 y y 的機率,變成:
P ( y x ) = y ^ y ( 1 y ^ ) 1 y P(y|x) = \hat{y}^{y}(1-\hat{y})^{1-y}
我們希望這個機率越大越好。

為了簡化計算,在等式兩邊取log:
log ( P ( y x ) ) = y log ( y ^ ) + ( 1 y ) log ( 1 y ^ ) \log(P(y|x)) = y\log(\hat{y})+(1-y)\log(1-\hat{y})

我們原來的目標是希望最大化 P ( y x ) P(y|x) ,經過變換後,目標變成最大化 log ( P ( y x ) ) \log(P(y|x))

log是嚴格遞增函數,如果輸入 x 1 > x_1> 輸入 x 2 x_2 ,則輸出 log ( x 1 ) > \log(x_1)> 輸出 log ( x 2 ) \log(x_2) 。即兩數在對應前後的大小關係不會被改變。
因為log函數有這種特性,所以我們可以說最大化 P ( y x ) P(y|x) 與最大化 log ( P ( y x ) ) \log(P(y|x)) 這兩個目標是一致的。

我們希望 log ( P ( y x ) ) \log(P(y|x)) 最大化,等義地,即是 log ( P ( y x ) ) -\log(P(y|x)) 最小化。因此我們可以將這個值當成loss,作為要優化的目標。

以下就是推導出來的交叉熵損失函數
C r o s s E n t r o p y L o s s = [ y log ( y ^ ) + ( 1 y ) log ( 1 y ^ ) ] CrossEntropyLoss = -[y\log(\hat{y})+(1-y)\log(1-\hat{y})]

交叉熵代價函數則是所有交叉熵損失函數的平均:
C r o s s E n t r o p y C o s t = 1 m i = 1 m [ y ( i ) log ( y ^ ( i ) ) + ( 1 y ( i ) ) log ( 1 y ^ ( i ) ) ] CrossEntropyCost = -\frac{1}{m}\sum_{i=1}^{m} [y^{(i)}\log(\hat{y}^{(i)})+(1-y^{(i)})\log(1-\hat{y}^{(i)})]

使用均方誤差,會出現多個局部最小值?

吳恩達在課程中有提到如果使用均方誤差,就會出現多個局部最小值,導致收斂困難。
讓我們來看個具體例子:

為了後續計算方便,此處假設 b = 0 b=0 。因此 y ^ = w x \hat{y}=wx

我們先來看看使用交叉熵損失函數的情況:
C r o s s E n t r o p y L o s s = [ y l o g ( y ^ ) + ( 1 y ) l o g ( 1 y ^ ) ] = [ y l o g ( 1 1 + e x w ) + ( 1 y ) l o g ( 1 1 1 + e x w ) ] CrossEntropyLoss \\= -[ylog(\hat{y})+(1-y)log(1-\hat{y})] \\= -[ylog(\frac{1}{1+e^{-xw}})+(1-y)log(1-\frac{1}{1+e^{-xw}})]
交叉熵代價函數:
C r o s s E n t r o p y C o s t = 1 m i = 1 m [ y ( i ) l o g ( y ^ ( i ) ) + ( 1 y ( i ) ) l o g ( 1 y ^ ( i ) ) ] = 1 m i = 1 m [ y ( i ) l o g ( 1 1 + e x ( i ) w ) + ( 1 y ( i ) ) l o g ( 1 1 1 + e x ( i ) w ) ] CrossEntropyCost \\= -\frac{1}{m}\sum_{i=1}^{m}[y^{(i)}log(\hat{y}^{(i)})+(1-y^{(i)})log(1-\hat{y}^{(i)})] \\= -\frac{1}{m}\sum_{i=1}^{m}[y^{(i)}log(\frac{1}{1+e^{-x^{(i)}w}})+(1-y^{(i)})log(1-\frac{1}{1+e^{-x^{(i)}w}})]

再來看平方損失函數:
S q u a r e L o s s = ( y y ^ ) 2 = ( y 1 1 + e x w ) 2 SquareLoss \\= (y-\hat{y})^2 \\= (y-\frac{1}{1+e^{-xw}})^2

均方誤差函數:
M e a n S q u a r e d E r r o r = 1 m i = 1 m [ ( y ( i ) 1 1 + e x ( i ) w ) 2 ] MeanSquaredError \\= \frac{1}{m}\sum_{i=1}^{m}[(y^{(i)}-\frac{1}{1+e^{-x^{(i)}w}})^2]

在此假設{( x 1 x_1 , y 1 y_1 ), ( x 2 x_2 , y 2 y_2 ), ( x 3 x_3 , y 3 y_3 )}分別為{(0.84,1),(0.03,1),(0.15,0)}。

然後把這些值帶入上面 C r o s s E n t r o p y C o s t CrossEntropyCost M e a n S q u a r e d E r r o r MeanSquaredError 兩個式子:
C r o s s E n t r o p y C o s t = 1 3 [ 1 log ( 1 1 + e 0.84 w ) + ( 1 1 ) log ( 1 1 1 + e 0.84 w ) + 1 log ( 1 1 + e 0.03 w ) + ( 1 1 ) log ( 1 1 1 + e 0.03 w ) + 0 log ( 1 1 + e 0.15 w ) + ( 1 0 ) log ( 1 1 1 + e 0.15 w ) ] = 1 3 [ log ( 1 1 + e 0.84 w ) + log ( 1 1 + e 0.03 w ) + log ( 1 1 1 + e 0.15 w ) ] CrossEntropyCost \\= -\frac{1}{3}[1\log(\frac{1}{1+e^{-0.84w}})+(1-1)\log(1-\frac{1}{1+e^{-0.84w}})\\ +1\log(\frac{1}{1+e^{-0.03w}})+(1-1)\log(1-\frac{1}{1+e^{-0.03w}})\\ +0\log(\frac{1}{1+e^{-0.15w}})+(1-0)\log(1-\frac{1}{1+e^{-0.15w}})] \\= -\frac{1}{3}[\log(\frac{1}{1+e^{-0.84w}}) +\log(\frac{1}{1+e^{-0.03w}})+\log(1-\frac{1}{1+e^{-0.15w}})]

M e a n S q u a r e d E r r o r = 1 3 [ ( 1 1 1 + e 0.84 w ) 2 + ( 1 1 1 + e 0.03 w ) 2 + ( 0 1 1 + e 0.15 w ) 2 ] = 1 3 [ ( 1 1 1 + e 0.84 w ) 2 + ( 1 1 1 + e 0.03 w ) 2 + ( 1 1 + e 0.15 w ) 2 ] MeanSquaredError \\= \frac{1}{3}[(1-\frac{1}{1+e^{-0.84w}})^2+(1-\frac{1}{1+e^{-0.03w}})^2+(0-\frac{1}{1+e^{-0.15w}})^2] \\= \frac{1}{3}[(1-\frac{1}{1+e^{-0.84w}})^2+(1-\frac{1}{1+e^{-0.03w}})^2+(\frac{1}{1+e^{-0.15w}})^2]

圖形

我們可以利用desmos這個工具將以上兩個代價函數畫出來。

CrossEntropyCost

cross entropy loss3
想要進一步查看此圖可以前往 :https://www.desmos.com/calculator/ls4kcjmeab
從上圖中我們可以看出CrossEntropyCost(藍色)只有一個局部最小值(即是全局最小值)。

MeanSquareCost

square loss and mean squared error
想要更仔細地查看這張圖片,可以前往 https://www.desmos.com/calculator/vqqkhehydi
我們可以看到上圖中MeanSquaredError(紫色)在w=-9.582及2.002處都有局部最小值。

其它例子

使用MeanSquareCost會有多個局部最小值,讓我們再看看其它例子:
3
2
上面的例子是筆者用Python程式找出來的,如果有興趣,可以前往GitHub頁面查看。

理論證明

上面CrossEntropyCost的例子只有一個局部最小值(即全局最小值)。
而MeanSquaredError的例子中則有多個局部最小值。

這裡我們試著從數學的角度來看,試圖為上述現象給出理論支持。

凸函數

在此之前先簡單介紹一下convex函數(凸函數)。
維科百科中凸函數的定義:

A twice differentiable function of one variable is convex on an interval if and only if its second derivative is non-negative there

它說明了一個二次可微的函數是凸函數 若且唯若 它的二階導數是非負的。

又根據:

if w 1 , , w n 0 w 1 , , w n 0 {\displaystyle w_{1},\ldots ,w_{n}\geq 0} {\displaystyle w_{1},\ldots ,w_{n}\geq 0} and f 1 , , f n f 1 , , f n {\displaystyle f_{1},\ldots ,f_{n}} {\displaystyle f_{1},\ldots ,f_{n}} are all convex, then so is w 1 f 1 + + w n f n . w 1 f 1 + + w n f n . {\displaystyle w_{1}f_{1}+\cdots +w_{n}f_{n}.} {\displaystyle w_{1}f_{1}+\cdots +w_{n}f_{n}.}

即:n個凸函數的加權和仍是一個凸函數

CrossEntropyCost

因為CrossEntropyLoss = ( log ( x ) ) = ( 1 x ) = 1 x 2 (-\log(x))'' = (-\frac{1}{x})' = \frac{1}{x^2} ,所以我們知道CrossEntropyLoss 是凸函數。

而CrossEntropyCost是多個CrossEntropyLoss的平均值,所以它也是凸函數。

既然CrossEntropyCost是凸函數,那麼它只有一個局部最小值也不足為奇了。

MeanSquaredError

(SquareLoss when y=0) = ( ( 1 1 + e x w ) 2 ) = 2 x 2 e 2 x w ( e x w 2 ) ( e x w + 1 ) 4 (({\frac{1}{1+e^{-xw}}})^2)'' = -\frac{2x^2e^{2xw}(e^{xw}-2)}{(e^{xw}+1)^4}
(SquareLoss when y=1) = ( ( 1 1 1 + e x w ) 2 ) = 2 x 2 e x w ( 2 e x w 1 ) ( e x w + 1 ) 4 ((1-{\frac{1}{1+e^{-xw}}})^2)'' = \frac{2x^2e^{xw}(2e^{xw}-1)}{(e^{xw}+1)^4}
以上兩式皆有可能為負,所以它們並非凸函數。

因為SquareLoss 的二階導數可能為負
→MeanSquaredError的二階導數可能為負
→MeanSquaredError的一階導數可能會忽大忽小
→MeanSquaredError的一階導數可能會多次穿過零點
→MeanSquaredError可能會有多個局部最小值

結論

邏輯回歸使用交叉熵損失函數有兩個主要因素。
第一:交叉熵是其最大似然函數。
第二:使用均方誤差會導致代價函數空間有多個局部最小值,交叉熵函數則沒有這個問題。
而是否有多個局部最小值則是由該函數是否為凸函數來決定。

參考連結

desmos
Derivative Calculator
https://www.desmos.com/calculator/ls4kcjmeab
https://www.desmos.com/calculator/vqqkhehydi
產生文中例子的Python程式

其它邏輯回歸的文章:
為何說L1正則化會使得權重變得稀疏?
為何說邏輯回歸是線性模型?
為何邏輯回歸可以使用0來初始化,而神經網路不行?

猜你喜欢

转载自blog.csdn.net/keineahnung2345/article/details/83717856