Implementing Logistic Regression with Gradient Descent (Python Code)

In logistic regression, the loss function is derived from maximum likelihood estimation: the product, over all samples, of the probability that each sample is classified correctly.
$h(x_i)$ denotes the probability that the input $x_i$ is classified as 1.
We usually define $h(x_i)=\frac{1}{1+e^{-z}}$ (the sigmoid function), where $z=x\theta^T$, $x$ is the input, and $\theta$ is the current parameter vector.

def sigmoid(z):
    # Map any real z to a probability in (0, 1)
    return 1/(1+np.exp(-z))
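
One caveat not in the original post: for large negative z, np.exp(-z) overflows float64 and NumPy emits a runtime warning. A common guard is to clip z first (scipy.special.expit is another stable alternative); a minimal sketch:

def sigmoid_stable(z):
    # Clipping is safe: beyond |z| = 500 the sigmoid already
    # saturates to 0 or 1 within float64 precision, and
    # np.exp(500) is still far below the float64 maximum.
    z = np.clip(z, -500, 500)
    return 1/(1+np.exp(-z))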

$$cost(x,y,\theta)=\prod_{i=1}^{m}h(x_i)^{y_i}\,(1-h(x_i))^{1-y_i}$$
Multiplying many probabilities and raising them to powers is inconvenient in matrix form (and the product quickly underflows), so we take the log of both sides to turn the product into a sum:
$$\ln(cost(x,y,\theta))=\sum_{i=1}^{m}\left[y_i\ln(h(x_i))+(1-y_i)\ln(1-h(x_i))\right]$$
Since the logarithm is monotonic and we only care about where the extremum is, we can use $\ln(cost(x,y))$ directly as the quantity to optimize. By convention we minimize rather than maximize, so we negate it, and we divide by $m$ to get an average over samples. The loss function is therefore $f(x,y,\theta)=-\ln(cost(x,y,\theta))/m$, which is what the code below returns.

def cost_func(theta, x, y):
    x = np.matrix(x)
    y = np.matrix(y)
    theta = np.matrix(theta)
    z = x*theta.T
    # y_i*ln(h(x_i)) and (1-y_i)*ln(1-h(x_i)), elementwise
    pos = np.multiply(y, np.log(sigmoid(z)))
    neg = np.multiply(1-y, np.log(1-sigmoid(z)))
    # Negative average log-likelihood
    return np.sum(-pos-neg)/len(x)
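
Note that np.log(sigmoid(z)) returns -inf when sigmoid(z) rounds to exactly 0 or 1, which happens for extreme z. A common guard, not in the original post (eps is an arbitrary small constant), is to clip the probabilities before taking the log:

def cost_func_safe(theta, x, y, eps=1e-12):
    # Same loss as cost_func, but keep probabilities strictly
    # inside (0, 1) so np.log never returns -inf.
    x = np.matrix(x)
    y = np.matrix(y)
    theta = np.matrix(theta)
    p = np.clip(sigmoid(x*theta.T), eps, 1-eps)
    pos = np.multiply(y, np.log(p))
    neg = np.multiply(1-y, np.log(1-p))
    return np.sum(-pos-neg)/len(x)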

Taking the partial derivative of this loss with respect to $\theta$ gives the gradient. The derivation takes a few steps:
$$\begin{aligned}\frac{\partial f(x,y,\theta)}{\partial \theta}&=-\frac{1}{m}\left(\frac{y}{h(x)}h'(x)-\frac{1-y}{1-h(x)}h'(x)\right)\\&=-\frac{1}{m}\cdot\frac{y-h(x)}{h(x)(1-h(x))}\,h'(x)\\&=-\frac{1}{m}\cdot\frac{y-h(x)}{h(x)(1-h(x))}\cdot h(x)(1-h(x))\cdot z'\\&=\frac{(h(x)-y)\,x}{m}\end{aligned}$$
(using the sigmoid identity $h'=h(1-h)\cdot z'$ and $z'=x$, since $z=x\theta^T$)
The final result has the same form as the gradient in multivariate linear regression, with $x\theta^T$ replaced by $h(x)=\frac{1}{1+e^{-x\theta^T}}$.
The gradient function:

def grident(theta, x, y):
    theta = np.matrix(theta)
    x = np.matrix(x)
    y = np.matrix(y)
    para_num = x.shape[1]
    grad = np.zeros(para_num)
    # h(x) - y for every sample
    error = sigmoid(x*theta.T)-y
    # Partial derivative for parameter i: mean of error * x_i
    for i in range(para_num):
        term = np.multiply(error, x[:, i])
        grad[i] = np.sum(term)/len(x)
    return grad
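
The per-parameter loop can be collapsed into a single matrix product. This vectorized version (my rewrite, not from the original post; grident_vec is my name) returns the same values:

def grident_vec(theta, x, y):
    # error is (m, 1), so error.T * x is (1, n): one product
    # computes all n partial derivatives at once.
    theta = np.matrix(theta)
    x = np.matrix(x)
    y = np.matrix(y)
    error = sigmoid(x*theta.T)-y
    return np.asarray((error.T*x)/len(x)).ravel()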

The gradient-descent iteration function:

def f(x, y, theta, alpha, iters):
    theta = np.matrix(theta)
    x = np.matrix(x)
    y = np.matrix(y)
    # Fixed number of steps with a constant learning rate alpha
    for i in range(iters):
        theta = theta - alpha*grident(theta, x, y)
    return theta
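
To judge whether alpha and iters are reasonable, it helps to record the loss at every step. A minimal sketch, not in the original post (f_with_history is my name):

def f_with_history(x, y, theta, alpha, iters):
    # Same descent loop, but also record the loss per iteration
    # so convergence can be inspected or plotted afterwards.
    theta = np.matrix(theta)
    x = np.matrix(x)
    y = np.matrix(y)
    costs = []
    for i in range(iters):
        theta = theta - alpha*grident(theta, x, y)
        costs.append(cost_func(theta, x, y))
    return theta, costs

If the recorded losses are still falling steadily at the last iteration, iters is too small or alpha too conservative; if they oscillate or grow, alpha is too large.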

At prediction time we pick a probability threshold; here, if the predicted probability of class 1 is at least 0.5, we output 1.

def predict(theta, x):
    probability = sigmoid(x*theta.T)
    return [1 if p >= 0.5 else 0 for p in probability]

Gradient descent in general only finds a local optimum, and the result can depend heavily on the starting point. (The logistic-regression loss is in fact convex, so any local minimum is global; with this implementation the starting point mostly determines whether the small learning rate and fixed iteration budget are enough to get close to it.)
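
One practical way to make the descent far less sensitive to the starting point, and to allow a much larger alpha, is to standardize the feature columns before training. This is my suggestion, not part of the original post (standardize is my name):

def standardize(x):
    # Scale every non-intercept column to zero mean and unit
    # variance; column 0 is the bias column of ones and is kept.
    x = np.array(x, dtype=float)
    mean = x[:, 1:].mean(axis=0)
    std = x[:, 1:].std(axis=0)
    x[:, 1:] = (x[:, 1:] - mean) / std
    return x, mean, std

The returned mean and std must be applied to any new data before calling predict, so that training and prediction see features on the same scale.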

Each row of the dataset contains three comma-separated values: x1, x2, y.

Complete code

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
def sigmoid(z):
    return 1/(1+np.exp(-z))
def cost_func(theta, x, y):
    x = np.matrix(x)
    y = np.matrix(y)
    theta = np.matrix(theta)
    z = x*theta.T
    pos = np.multiply(y, np.log(sigmoid(z)))
    neg = np.multiply(1-y, np.log(1-sigmoid(z)))
    return np.sum(-pos-neg)/len(x)
def grident(theta, x, y):
    theta = np.matrix(theta)
    x = np.matrix(x)
    y = np.matrix(y)
    para_num = x.shape[1]
    grad = np.zeros(para_num)
    error = sigmoid(x*theta.T)-y
    for i in range(para_num):
        term = np.multiply(error, x[:, i])
        grad[i] = np.sum(term)/len(x)
    return grad
def f(x, y, theta, alpha, iters):
    theta = np.matrix(theta)
    x = np.matrix(x)
    y = np.matrix(y)
    for i in range(iters):
        theta = theta - alpha*grident(theta, x, y)
    return theta
def predict(theta, x):
    probability = sigmoid(x*theta.T)
    return [1 if p >= 0.5 else 0 for p in probability]
path = os.getcwd()
data = pd.read_csv(os.path.join(path, 'data', 'ex2data1.txt'), names = ['exam1', 'exam2', 'admit'])

# Prepend a column of ones so theta[0] acts as the intercept
data.insert(0, 'Ones', 1)
cols = data.shape[1]
x = np.array(data.iloc[:, :cols-1].values)
y = np.array(data.iloc[:, cols-1:cols].values)
th = np.zeros(3)
#import scipy.optimize as opt
#result = opt.fmin_tnc(func = cost_func, x0 = th, fprime= grident, args = (x, y))

# Hand-picked starting point; a poor start converges very slowly
# with this small learning rate
th = np.array([-20,0.21,0.2])
th = f(x, y, th, 0.0001, 1000)
print(cost_func(th, x, y))
cnt = 0
predictions = predict(th, x)
for i in range(len(x)):
    if predictions[i] == y[i]:
        cnt+=1
print("Accuracy:", cnt/len(x)*100, '%')

Training data

From https://www.johnwittenauer.net/machine-learning-exercises-in-python-part-3/
34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0
70.66150955499435,92.92713789364831,1
76.97878372747498,47.57596364975532,1
67.37202754570876,42.83843832029179,0
89.67677575072079,65.79936592745237,1
50.534788289883,48.85581152764205,0
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1

Reposted from blog.csdn.net/qq_43202683/article/details/104734894