Multi-class Logistic Regression
在
Logistic Regression摘记 一文中对二元Logistic回归进行了详细的介绍,本文主要描述采用
softmax 函数实现多元Logistic回归。
1. softmax函数
对于某个输入
x,其对应的
softmax 输出为向量值
y=[y1,⋯,yk,⋯,yK]T,且满足
k=1∑Kyk=1。
(1) 分类问题
(K>2) 中使用
softmax 函数表示输出值分量:
yk=j=1∑Keajeak=j=1∑KewjTx+bjewkTx+bk , k=1,2,⋯,K
(2) 对于多元Logistic回归
:
aj=wjTx+bj
其中,
x∈Rd,wj=[wj,1,wj,2,⋯,wj,d]T (j=1,2,⋯,K)
若记
wj∗=[wjT,bj]T 和
x∗=[xT,1]T,则
aj=wjTx+bj=(wj∗)Tx∗
【为了方便描述】可以略掉 ‘
∗’ 号,直接写成:
aj=wjTx
(3) 将输出值分量
yk 描述成后验概率的形式:
yk=p(y=k∣x)=j=1∑KewjTxewkTx , k=1,2,⋯,K
2. 与二元Logistic回归的关系
对比二元Logistic回归
-
x为正例的概率:
p(y=1∣x)=1+e−(wTx+b)1
-
x为负例的概率:
p(y=0∣x)=1−1+e−(wTx+b)1=1+e(wTx+b)e−(wTx+b)=1+e−(wTx+b)1
当
K=2 时,softmax函数实际上等同于二元Logistic回归(假设
k={0,1}):
⎩⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎧ y0=p(y=0∣x)=ew0Tx+ew1Txew0Tx=1+e(w1−w0)Tx1 y1=p(y=1∣x)=ew0Tx+ew1Txew1Tx=1+e−(w1−w0)Tx1
令
w^=w1−w0,那么类后验概率就是二元Logistic回归中情形。
3. 误差函数
针对多元Logistic回归
,首先要写出其误差函数。
假设训练样本集为
{(xn,cn)}n=1N,其中
xn∈Rd,cn∈{1,2,⋯,K},参数为
W=(w1T,w2T,⋯,wKT,b)T。
二元Logistic回归
假设训练样本为
{(xn,cn)}n=1N,其中
xn∈Rd,cn∈{0,1},似然函数为:
L(w,b)=n=1∏Nh(xn)cn[1−h(xn)]1−cn , h(x)=p(c=1∣x)=1+e−(wTx+b)1
取“负的对数似然函数”作为误差函数,即:
l(w,b)=−lnL(w,b)。
3.1 多元回归的1-of-K表示
(1) 用变量
c∈{1,2,⋯,K} 表示输入
x 所对应的类别
(2) 引入目标向量
t=[0,⋯,0,1,0,⋯,0]T∈RK,满足
tk=1,tj=0 (j=k)
表示“输入
x 属于第
k 类” 或者说变量
c=k
(3) 用向量值
y=[y1,⋯,yk,⋯,yK]T 表示输入
x 所对应的
softmax输出
yk=p(c=k∣x)=p(tk=1∣x)=k=1∏Kp(tk∣x)tk=j=1∑KewjTxewkTx
显然,
k=1∑Kyk=k=1∑Kp(y=k∣x)=1
3.2 训练样本集的似然函数
(1) 对于第
n 个训练样本
(xn,cn),其
softmax 输出为
yn=[yn1,⋯,ynk,⋯,ynK]T,且
ynk=p(cn=k∣xn)=p(tnk=1∣xn)=k=1∏Kp(tnk∣xn)tnk
(2) 训练样本集
{(xn,cn)}n=1N 的似然函数
L(W) 为:
L(W)=n=1∏Np(cn=k∣xn)=n=1∏Np(tnk=1∣xn)=n=1∏Nk=1∏Kp(tnk∣xn)tnk
3.3 交叉熵误差函数
取“负的对数似然函数”为(交叉熵)误差函数
(cross−entropy error function)
l(W)=−lnL(W)=−n=1∑Nk=1∑Ktnklnp(tnk∣xn)=−n=1∑Nk=1∑Ktnklnynk
使用交叉熵作为误差函数,是因为:
(1) 若训练样本
xn 的类别
cn=k,则对应的目标向量
tn 只有第
k 个分量
tnk=1,而其他分量
tnj=0 (j=k)。
(2) 在训练过程中,
ynk 是训练样本
xn 所对应
softmax 输出的第
k 个分量(训练样本的正确类别
k 所对应的输出分量值)。
(3) 如果正确类别
k 所对应分量值
ynk 越大,
lnynk 也越大,交叉熵就越小,训练误差也就越小。
(4) 理想情况下,正确类别
k 所对应分量值
ynk=1,∀ n,那么交叉熵为
0,也就是没有训练误差。
4. 最大似然估计
为了求出参数
W=(w1T,w2T,⋯,wKT,b)T,同样采用最大似然估计。
可以将训练样本集分成
K 个子集
C1,⋯,Ck,⋯,CK,第
k 个子集
Ck 中的所有样本
xn 的类别都为
cn=k,对应的目标向量
tn 都满足
tnk=1,tnj=0 (j=k),由误差函数的表达式:
l(W)=−n=1∑Nk=1∑Ktnklnynk=−n∈C1∑tn1lnyn1−⋯−n∈Ck∑tnklnynk−⋯−n∈CK∑tnKlnynK=−n∈C1∑lnyn1−⋯−n∈Ck∑lnynk−⋯−n∈CK∑lnynK
对
l(W) 求参数
wk 的偏导分为两个部分:
(1) 对
l(W) 的第
k 个分量
lk(W)=−n∈Ck∑lnynk 求参数
wk 的偏导
∂wk∂lk(W)=−n∈Ck∑ynk1∂wk∂ynk(wkTxn=xnTwk)=−n∈Ck∑ynk1(j=1∑KewjTxn)2(ewkTxn)′j=1∑KewjTxn−ewkTxn(j=1∑KewjTxn)′=−n∈Ck∑ynk1(j=1∑KewjTxn)2(ewkTxn)xnj=1∑KewjTxn−ewkTxn(ewkTxn)xn=−n∈Ck∑ynk1j=1∑KewjTxn(ewkTxn)xnj=1∑KewjTxnj=1∑KewjTxn−ewkTxn=−n∈Ck∑ynk1ynkxn(1−ynk)=−n∈Ck∑(1−ynk)xn
(2) 对
l(W) 的第
i (i=k) 个分量
li(W)=−n∈Ci∑lnyni 求参数
wk 的偏导
∂wk∂li(W)=−n∈Ci∑yni1∂wk∂yni=−n∈Ci∑yni1(j=1∑KewjTxn)2(ewiTxn)′j=1∑KewjTxn−ewiTxn(j=1∑KewjTxn)′=−n∈Ci∑yni1(j=1∑KewjTxn)2−ewiTxn(ewkTxn)xn=−n∈Ci∑yni1j=1∑KewjTxnewiTxnj=1∑KewjTxn−(ewkTxn)xn=−n∈Ci∑yik1yni(−ynk)xn=−n∈Ci∑(−ynk)xn
综合起来,两个公式可以表示为:
∂wk∂l(W)=−n=1∑N(tnk−ynk)xn
采用梯度下降法时,权值更新公式为:
W(m+1)=W(m)−α∂W∂l(W)
其中
α 为梯度下降法的步长。
代码实现(mnist数据集)
import numpy as np
from dataset.mnist import load_mnist
def softmax_train(train,target,alpha,num):
xhat = np.concatenate((train,np.ones((len(train),1))),axis=1)
nparam = len(xhat.T)
beta = np.random.rand(nparam,10)
for i in range(num):
wtx = np.dot(xhat,beta)
wtx1 = wtx - np.max(wtx,axis=1).reshape(len(train),1)
e_wtx = np.exp(wtx1)
yx = e_wtx/np.sum(e_wtx,axis=1).reshape(len(xhat),1)
print(' #'+str(i+1)+' : '+str(cross_entropy(yx,target)))
t1 = target - yx
t2 = np.dot(xhat.T, t1)
beta = beta + alpha*t2
return beta
def cross_entropy(yx,t):
sum1 = np.sum(yx*t,axis=1)
ewx = np.log(sum1+0.000001)
return -np.sum(ewx)/len(yx)
def classification(test, beta, test_t):
xhat = np.concatenate((test,np.ones((len(test),1))),axis=1)
wtx = np.dot(xhat,beta)
output = np.where(wtx==np.max(wtx,axis=1).reshape((len(test),1)))[1]
print("Percentage Correct: ",np.where(output==test_t)[0].shape[0]/len(test))
return np.array(output,dtype=np.uint8)
if __name__ == '__main__':
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
nread = 60000
train_in = x_train[:nread,:]
train_tgt = np.zeros((nread,10))
test_in = x_test[:10000,:]
test_t = t_test[:10000]
for i in range(nread):
train_tgt[i,t_train[i]] = 1
beta = softmax_train(train_in,train_tgt,0.001,60)
print(beta)
result = classification(test_in, beta, test_t)
测试结果:
#1 : 5.626381119337011
#2 : 5.415158063701459
#3 : 10.959830171565791
#4 : 8.062787294189338
#5 : 7.4643357380759765
#6 : 9.070059164063883
#7 : 9.81079287953052
#8 : 7.13921201579068
#9 : 7.176904417794094
#10 : 4.607102717465571
#11 : 3.9215536116316625
#12 : 4.199011112147004
#13 : 4.135313269465135
#14 : 3.214738972020379
#15 : 2.804664146283606
#16 : 2.901161881757491
#17 : 2.9996749271603456
#18 : 2.609904566490558
#19 : 2.6169338357951197
#20 : 2.538795429964946
#21 : 2.7159497447897256
#22 : 2.634980803678192
#23 : 2.974848646434367
#24 : 3.1286179795674154
#25 : 3.2208869228881407
#26 : 2.548910343301664
#27 : 2.5298981152704743
#28 : 2.3826001247525035
#29 : 2.4498572463653243
#30 : 2.3521370651353837
#31 : 2.4309032741212664
#32 : 2.366133209606206
#33 : 2.4462922376053364
#34 : 2.3850487760328933
#35 : 2.4481429887352792
#36 : 2.370067560256672
#37 : 2.376729198498193
#38 : 2.297488373847759
#39 : 2.265126273640295
#40 : 2.258495714414137
#41 : 2.327524884607823
#42 : 2.3130200962416128
#43 : 2.290046983208286
#44 : 2.1465196716967805
#45 : 2.0969060851949677
#46 : 1.8901858209971119
#47 : 1.844354795879705
#48 : 1.6340799726564934
#49 : 1.60064459794013
#50 : 1.4667008762515674
#51 : 1.4453938385590863
#52 : 1.3767004735390218
#53 : 1.359619935503484
#54 : 1.3153462460865966
#55 : 1.309895715988472
#56 : 1.2799649790773286
#57 : 1.2807586745656392
#58 : 1.2559139323742572
#59 : 1.2582212637839076
#60 : 1.237819660093416
权值:
[[7.69666472e-01 2.16009202e-01 9.81729719e-01 … 5.32453082e-01
7.88719040e-01 5.14326954e-01]
[3.90401951e-01 5.84040914e-01 7.94883641e-01 … 8.02009249e-01
3.29345264e-02 6.70861290e-01]
[8.69075434e-02 8.43381782e-01 4.77683466e-01 … 8.71965798e-01
4.47018470e-04 5.07498017e-01]
…
[7.96129468e-01 6.14364951e-01 8.32783158e-01 … 6.53493763e-01
2.06235991e-01 8.60469591e-01]
[1.67070291e-01 3.23211147e-02 2.41519794e-01 … 6.56026583e-01
5.98396521e-01 5.42304452e-01]
[8.43299673e-01 6.22843596e-01 6.05652099e-02 … 1.10339403e-01
1.61855811e-01 3.29385438e-01]]
识别率:
Percentage Correct: 0.9037