前言
代码可在Github上下载:代码下载
支持向量机,一个理论先于实践的产物,由Corinna Cortes和Vapnik等在1995年首次提出的,取得了非常好的效果。
SVM分为线性可分(又可分硬间隔与软间隔)和线性不可分的问题,线性不可分可通过核函数进行映射到空间里,从而线性可分。
线性可分支持向量机
模型:
函数间隔(关于样本点):
函数间隔(关于数据集T):
函数间隔表示一个确信度
几何间隔:
其实这里的几何间隔的就是点直线的距离。
由此我们可以得到
同时等价于
构造拉格朗日乘子
通过求导等于0课得
将
代入到
可得
我们的目标是
然后用对偶来表示,最终的对偶形式
但是有时数据难免有偏差(比如数据点不允许在间隔内),如果不允许这些数据点在间隔内,我们很难找到理想的间隔来划分数据,于是就有了软间隔最大化,C作为一个惩罚参数,C值大则对误分类的惩罚增大,反之同理。
可以看到软间隔最大化的公式主要差别在于
的取值大小。
到这步,我们需要求出
来得到
,如何求出
通常是选择SMO算法。
《统计学习方法》P124 7.4序列最小化优化算法开始专门为这部分进行了一个介绍。
首先这个算法最主要的是想求出
,有了
我们就能得到
和
。
那SMO的算法主要是分两步,P128有详细介绍,这里简单描述下,第一步,找出第一个
,这个可以找到一个不满足KKT条件的,进行优化,第二步,找出能使得
最大的数,这里可以通过对
进行判断,如果
是正的,就选个最小的
作为
,反之同理。
for i in range(N): #外循环
alpha1old = alpha[i].copy() #未更新的alpha,也就是alpha_old
y1 = labels[i] #y1
data1 = dataSet[i]
g1 = self.calcG(alpha, labels, data1, dataSet, b)
alphaIndex1 = -1 #存储不满足KKT条件的alphaIndex1
if alpha1old == 0: #判断是否满足KKT条件 (7.111)
if y1 * g1 < 1:
alphaIndex1 = i
flag = True
if alpha1old > 0 and alpha[i] < C: #(7.112)
if y1 * g1 != 1:
alphaIndex1 = i
flag = True
if alpha1old == C: #(7.1132)
if y1 * g1 <= 1:
alphaIndex1 = i
flag = True
if alphaIndex1 == -1: #说明满足KKT条件,继续下一次循环来找alpha1
continue
E1 = g1 - y1 #(7.105)
那首先有个外循环,来遍历每一个
来找第一个
,书本说的是选取违反KKT最严重的样本点,那这里做了简单处理,只要是违反的就行,并且算出
。
其中E的计算方式为
。
for j in range(N): #内循环
if i != j: #相等就没法选了
yj = labels[j]
gj = self.calcG(alpha, labels, dataSet[j], dataSet, b)
Ej = gj - yj
if E1 > 0: #说明要选最小的E2
if Ej < selectedE2:
selectedE2 = Ej
alphaIndex2 = j
else:
if Ej > selectedE2:
selectedE2 = Ej
alphaIndex2 = j
以上的代码就是寻找出第二个 。
L = 0 # P126末尾两段
H = 0
y2 = labels[alphaIndex2]
alpha2old = alpha[alphaIndex2].copy()
data2 = dataSet[alphaIndex2]
E2 = selectedE2
if (y1 == y2): #alpha2取值范围必须限制在L<alpha2<H
L = np.maximum(0, alpha2old - alpha1old) # L
H = np.minimum(C, C + alpha2old - alpha1old) #H
else:
L = np.maximum(0, alpha2old + alpha1old - C) # L
H = np.minimum(C, C + alpha2old + alpha1old) #H
eta = self.calcK(data1, data1) + self.calcK(data2, data2) - 2 * self.calcK(data1, data2)
if eta == 0: #没法选
continue
alpha2new = alpha2old + (y2 * (E1 - E2)) / eta
if (alpha2new > H): # (7.108)
alpha[alphaIndex2] = H
elif (alpha2new < L):
alpha[alphaIndex2] = L
else:
alpha[alphaIndex2] = alpha2new
alpha1new = alpha1old * y1 * y2 * (alpha2old - alpha2new) #(7.109)
alpha[alphaIndex1] = alpha1new
OK,有了两个
,我们可以根据公式
(7.106),其中来得出一个未经处理的候选
,这个
由于有限制范围,需要处理下。
具体处理方法为(7.108):
有了
之后,,
b1new = -E1 - y1 * self.calcK(data1, data1) * (alpha1new - alpha1old) - y2 * self.calcK(data2, data1) * (alpha2new - alpha2old) + b #(7.115)
b2new = -E2 - y1 * self.calcK(data1, data1) * (alpha1new - alpha1old) - y2 * self.calcK(data2, data2) * (alpha2new - alpha2old) + b #(7.116)
if (alpha1new > 0 and alpha1new < C):
b = b1new
else:
b = (b1new + b2new) / 2
同时,由于计算过程中需要用到
,所以需要在此过程中更新
公式在P130页(7.115)和(7.116),
然后如果
和
同时满足
,那么
或者
,如果不符合,则选这两个值的中值。
weights = np.dot(np.multiply(alpha, labels), dataSet) #权重
至此,
全部都求出来了,我们通过
可以得到
。
就是刚才那个
了。
这里有个小插曲,
值也可以在每次得到两个
值后,通过
。
这个代码在python3环境下运行。
def loadDataSet(): #加载文件
data = list()
labels = list()
with open('testSet.txt') as f:
lines = f.readlines()
for line in lines:
line = line.rstrip().split('\t')
data.append([float(line[0]), float(line[1])])
labels.append(float(line[-1]))
return data, labels
def sign(x): #符号函数
if x >= 0:
return 1
else:
return -1
class SVM:
def train(self, dataSet, labels): #训练并返回权重和偏置
b = 0 #偏置
C = 1 #惩罚系数
flag = True #检验是否全部都满足KKT条件
maxIter = 100 #最大循环次数
iter = 0
N = len(dataSet) #数据的行数
M = len(dataSet[0]) #数据的列数,维数
alpha = np.zeros(N)
while iter < maxIter:
print(iter)
iter += 1
flag = False
for i in range(N): #外循环
alpha1old = alpha[i].copy() #未更新的alpha,也就是alpha_old
y1 = labels[i] #y1
data1 = dataSet[i]
g1 = self.calcG(alpha, labels, data1, dataSet, b)
alphaIndex1 = -1 #存储不满足KKT条件的alphaIndex1
if alpha1old == 0: #判断是否满足KKT条件 (7.111)
if y1 * g1 < 1:
alphaIndex1 = i
flag = True
if alpha1old > 0 and alpha[i] < C: #(7.112)
if y1 * g1 != 1:
alphaIndex1 = i
flag = True
if alpha1old == C: #(7.1132)
if y1 * g1 <= 1:
alphaIndex1 = i
flag = True
if alphaIndex1 == -1: #说明满足KKT条件,继续下一次循环来找alpha1
continue
E1 = g1 - y1 #(7.105)
alphaIndex2 = -1
if E1 > 0: #正的话要找E2的最小值,反之同理
selectedE2 = np.inf
else:
selectedE2 = -np.inf
for j in range(N): #内循环
if i != j: #相等就没法选了
yj = labels[j]
gj = self.calcG(alpha, labels, dataSet[j], dataSet, b)
Ej = gj - yj
if E1 > 0: #说明要选最小的E2
if Ej < selectedE2:
selectedE2 = Ej
alphaIndex2 = j
else:
if Ej > selectedE2:
selectedE2 = Ej
alphaIndex2 = j
'''
此时应该选到了alpha2了
'''
L = 0 # P126末尾两段
H = 0
y2 = labels[alphaIndex2]
alpha2old = alpha[alphaIndex2].copy()
data2 = dataSet[alphaIndex2]
E2 = selectedE2
if (y1 == y2): #alpha2取值范围必须限制在L<alpha2<H
L = np.maximum(0, alpha2old - alpha1old) # L
H = np.minimum(C, C + alpha2old - alpha1old) #H
else:
L = np.maximum(0, alpha2old + alpha1old - C) # L
H = np.minimum(C, C + alpha2old + alpha1old) #H
eta = self.calcK(data1, data1) + self.calcK(data2, data2) - 2 * self.calcK(data1, data2)
if eta == 0: #没法选
continue
alpha2new = alpha2old + (y2 * (E1 - E2)) / eta
if (alpha2new > H): # (7.108)
alpha[alphaIndex2] = H
elif (alpha2new < L):
alpha[alphaIndex2] = L
else:
alpha[alphaIndex2] = alpha2new
alpha1new = alpha1old * y1 * y2 * (alpha2old - alpha2new) #(7.109)
alpha[alphaIndex1] = alpha1new
b1new = -E1 - y1 * self.calcK(data1, data1) * (alpha1new - alpha1old) - y2 * self.calcK(data2, data1) * (alpha2new - alpha2old) + b #(7.115)
b2new = -E2 - y1 * self.calcK(data1, data1) * (alpha1new - alpha1old) - y2 * self.calcK(data2, data2) * (alpha2new - alpha2old) + b #(7.116)
if (alpha1new > 0 and alpha1new < C):
b = b1new
else:
b = (b1new + b2new) / 2
print(alpha)
weights = np.dot(np.multiply(alpha, labels), dataSet) #权重
return weights, b
def calcK(self, data1, data2): #线性核函数,返回内积
return np.dot(data1, data2)
def calcG(self, alpha, labels, data, dataSet, b): #计算g
sum = 0
for j in range(len(alpha)):
sum += alpha[j] * labels[j] * self.calcK(data, dataSet[j]) #g(x)的计算
return sum + b
if __name__ == '__main__':
dataSet, labels = loadDataSet()
svm = SVM()
weights, b = svm.train(dataSet, labels)
print(weights, b)
x = [1, 2]
f = sign(np.dot(weights, x) + b)
print(f)
以下是数据集,请保存为testSet.txt文件放在同目录下。
3.542485 1.977398 -1
3.018896 2.556416 -1
7.551510 -1.580030 1
2.114999 -0.004466 -1
8.127113 1.274372 1
7.108772 -0.986906 1
8.610639 2.046708 1
2.326297 0.265213 -1
3.634009 1.730537 -1
0.341367 -0.894998 -1
3.125951 0.293251 -1
2.123252 -0.783563 -1
0.887835 -2.797792 -1
7.139979 -2.329896 1
1.696414 -1.212496 -1
8.117032 0.623493 1
8.497162 -0.266649 1
4.658191 3.507396 -1
8.197181 1.545132 1
1.208047 0.213100 -1
1.928486 -0.321870 -1
2.175808 -0.014527 -1
7.886608 0.461755 1
3.223038 -0.552392 -1
3.628502 2.190585 -1
7.407860 -0.121961 1
7.286357 0.251077 1
2.301095 -0.533988 -1
-0.232542 -0.547690 -1
3.457096 -0.082216 -1
3.023938 -0.057392 -1
8.015003 0.885325 1
8.991748 0.923154 1
7.916831 -1.781735 1
7.616862 -0.217958 1
2.450939 0.744967 -1
7.270337 -2.507834 1
1.749721 -0.961902 -1
1.803111 -0.176349 -1
8.804461 3.044301 1
1.231257 -0.568573 -1
2.074915 1.410550 -1
-0.743036 -1.736103 -1
3.536555 3.964960 -1
8.410143 0.025606 1
7.382988 -0.478764 1
6.960661 -0.245353 1
8.234460 0.701868 1
8.168618 -0.903835 1
1.534187 -0.622492 -1
9.229518 2.066088 1
7.886242 0.191813 1
2.893743 -1.643468 -1
1.870457 -1.040420 -1
5.286862 -2.358286 1
6.080573 0.418886 1
2.544314 1.714165 -1
6.016004 -3.753712 1
0.926310 -0.564359 -1
0.870296 -0.109952 -1
2.369345 1.375695 -1
1.363782 -0.254082 -1
7.279460 -0.189572 1
1.896005 0.515080 -1
8.102154 -0.603875 1
2.529893 0.662657 -1
1.963874 -0.365233 -1
8.132048 0.785914 1
8.245938 0.372366 1
6.543888 0.433164 1
-0.236713 -5.766721 -1
8.112593 0.295839 1
9.803425 1.495167 1
1.497407 -0.552916 -1
1.336267 -1.632889 -1
9.205805 -0.586480 1
1.966279 -1.840439 -1
8.398012 1.584918 1
7.239953 -1.764292 1
7.556201 0.241185 1
9.015509 0.345019 1
8.266085 -0.230977 1
8.545620 2.788799 1
9.295969 1.346332 1
2.404234 0.570278 -1
2.037772 0.021919 -1
1.727631 -0.453143 -1
1.979395 -0.050773 -1
8.092288 -1.372433 1
1.667645 0.239204 -1
9.854303 1.365116 1
7.921057 -1.327587 1
8.500757 1.492372 1
1.339746 -0.291183 -1
3.107511 0.758367 -1
2.609525 0.902979 -1
3.263585 1.367898 -1
2.912122 -0.202359 -1
1.731786 0.589096 -1
2.387003 1.573131 -1