文章目录
SMO算法描述:
代码实现
确定第一个变量:
# 找第一个样本
def _pick_first(self,tol):
# mask,漏出需要的元素
con1 = self._alpha >0
con2 = self._alpha <self._c
# 复制成3分
# 预测值通过局部更新方式更新
err1 = self._y *self._prediction_cache -1
err2 = err1.copy()
err3 = err1.copy()
# 三种情况的处理
err1[(con1 & (err1 <= 0)) | (~con1 & (err1 > 0))] = 0
err2[((~con1 | ~con2) & (err2 != 0)) | ((con1 & con2) & (err2 == 0))] = 0
err3[(con2 & (err3 >= 0)) | (~con2 & (err3 < 0))] = 0
# 算出损失最大的那个
err = err1**2 + err2**2 + err3**2
idx = np.argmax(err)
# 如果该项的损失等于0,返回,否则返回选取的下标
if err[idx] < tol:
return
return idx
确定第二个变量:
# 找第二个样本,这里实现的是随机挑选
def _pick_second(self,idx1):
idx = np.random.randint(len(self._y))
while idx == idx1:
idx = np.random.randint(len(self._y))
return idx
更新alpha值:
def _update_alpha(self,idx1,idx2):
l,h = self._get_lower_bound(idx1, idx2), self._get_upper_bound(idx1, idx2)
y1, y2 = self._y[idx1],self._y[idx2]
e1 = self._prediction_cache[idx1] - y1
e1 = self._prediction_cache[idx2] - y2
# dK = K11 + K22 -2K12
eta = self._gram[idx1][idx1] + self._gram[idx2][idx2] - 2*self._gram[idx1][idx2]
a2_new = self._alpha[idx2] + (y2*(e1-e2)) / eta
# 约束条件约束
if a2_new > h:
a2_new = h
elif a2_new < l:
a2_new = l
a1_old, a2_old = self._alpha[idx1],self._alpha[idx2]
da2 = a2_new - a2_old
da1 = -y1*y2*da2
# 更新alpha
self._alpha[idx1] +=da1
self._alpha[idx2] = a2_new
# 根据da来更新dw,db,y^
self._update_dw_cache(idx1, idx2, da1, da2, y1, y2)
self._update_dw_cache(idx1, idx2, da1, da2, y1, y2, e1, e2)
# 注意参数传入的方法
self._update_pred_cache(idx1, idx2)
更新dw,db,y^,注意参数传入的方法:
# 更新dw
def _update_dw_cache(self, idx1, idx2, da1, da2, y1, y2):
self._dw_cache = np.array([da1*y1,da2*y2])
# 更新b
def _updata_b_cache(self, idx1, idx2, da1, da2, y1, y2, e1, e2):
gram_12 = self._gram[idx1][idx2]
b1 = -e1 - y1 * self._gram[idx1][idx1] * da1 - y2 * gram_12 * da2
b2 = -e2 - y1 * gram_12 * da1 - y2 * self._gram[idx2][idx2] * da2
# 这里分两种情况,两种情况下都是(b1 + b2) * 0.5
self._db_cache = (b1 + b2) * 0.5
self._b += self._db_cache
# y^ = y^ + dw1K1 + dw2k2 + db ,注意参数传入的方法
def _update_pred_cache(self,*args):
self._prediction_cache += self._db_cache
if len(args) == 1:
self._prediction_cache += self._dw_cache * self._gram[args[0]]
else:
self._prediction_cache += self._dw_cache.dot(self._gram[args, ...])
主要处理程序(一):一个回合的处理
def _fit(self,sample_weight, tol):
idx1 = self._pick_first(tol)
# 如果所有的样本误差均小于阈值
if idx1 is None:
return True
idx2 = self._pick_second(idx1)
# 更新
self._update_alpha(idx1,idx2)
主要处理程序(二):整体epoch处理
def fit(self, x, y, kernel="rbf", epoch=10**4, **kwargs):
    """Train the model on ``x`` / ``y`` for at most ``epoch`` rounds.

    Args:
        x: Training samples; converted with ``np.atleast_2d``.
        y: Training labels; converted with ``np.array``.
        kernel: ``"poly"`` or ``"rbf"``.  Any other value is accepted only
            if ``self._kernel`` has already been set on the instance.
        epoch: Maximum number of calls to ``_fit``.
        **kwargs: Optional settings:
            ``p`` (polynomial degree, default ``KernelConfig.default_p``),
            ``gamma`` (RBF width, default ``1 / n_features``),
            ``sample_weight`` (forwarded to ``_fit``, default ``None``),
            any name listed in ``self._fit_args_names`` (overrides the
            matching default in ``self._fit_args``).
            All keyword arguments are also forwarded to ``_prepare``.

    Raises:
        ValueError: If ``kernel`` is not supported and no kernel function
            has been set on the instance.
    """
    self._x, self._y = np.atleast_2d(x), np.array(y)
    if kernel == "poly":
        # Polynomial kernel: the degree defaults to KernelConfig.default_p.
        _p = kwargs.get("p", KernelConfig.default_p)
        self._kernel_name = "Polynomial"
        self._kernel_param = "degree = {}".format(_p)
        self._kernel = lambda _x, _y: KernelBase._poly(_x, _y, _p)
    elif kernel == "rbf":
        # RBF kernel: gamma defaults to the reciprocal of the feature count.
        _gamma = kwargs.get("gamma", 1 / self._x.shape[1])
        self._kernel_name = "RBF"
        self._kernel_param = "gamma = {}".format(_gamma)
        self._kernel = lambda _x, _y: KernelBase._rbf(_x, _y, _gamma)
    elif getattr(self, "_kernel", None) is None:
        # Fail early with a clear message instead of an AttributeError below.
        raise ValueError(
            "unsupported kernel {!r}; expected 'poly' or 'rbf'".format(kernel)
        )
    # Initialise the parameters; size them by the row count of the 2-D
    # sample matrix so they always agree with the Gram matrix.
    n_samples = len(self._x)
    self._alpha = np.zeros(n_samples)
    self._w = np.zeros(n_samples)
    self._prediction_cache = np.zeros(n_samples)
    self._gram = self._kernel(self._x, self._x)
    self._b = 0
    # Let _prepare initialise model-specific parameters.
    self._prepare(**kwargs)
    # Collect the arguments used inside the training loop: a keyword
    # argument wins, otherwise the stored default is kept.
    _fit_args = [
        kwargs.get(_name, _default)
        for _name, _default in zip(self._fit_args_names, self._fit_args)
    ]
    sample_weight = kwargs.get("sample_weight")
    for _ in range(epoch):
        # Stop as soon as every sample's error is small enough.
        if self._fit(sample_weight, *_fit_args):
            break
    self._update_params()
其他辅助函数:
# 定义计算多项式核矩阵函数
@staticmethod
def _poly(x, y, p):
return (x.dot(y.T) + 1)**p
# 定义计算RBF核矩阵函数
# x[...,None,:]为升维
@staticmethod
def _rbf(x, y, gamma):
return np.exp(-gamma*np.sum((x[...,None,:] - y)**2), axis=2)
class KernelConfig:
    """Default hyper-parameters used when the caller supplies none."""

    # Penalty factor used when no "c" keyword argument is given.
    default_c = 1
    # Polynomial degree used when no "p" keyword argument is given.
    default_p = 3
# 更新w,b,这里的w不同于我们的w,是w . K + b的w
def _update_params(self):
self._w = self._alpha * self._y
# 取O·C最大的a
_idx = np.argmax((self._alpha !=0) & (self._alpha != self._c))
self._b = self._y[idx] - np.sum(self._alpha * self._y * self._gram[_idx])
# Initialise the penalty factor.
def _prepare(self, **kwargs):
    """Set ``self._c`` from the ``"c"`` keyword argument.

    Args:
        **kwargs: May contain ``"c"``; when absent,
            ``KernelConfig.default_c`` is used.
    """
    if "c" in kwargs:
        penalty = kwargs["c"]
    else:
        penalty = KernelConfig.default_c
    self._c = penalty