感知机
感知机(perceptron)是支持向量机的和神经网络的基础。优点是简单易实现,缺点是只能分离线性可分数据集, 否则不收敛,且感知机学习得到的分离超平面不是唯一解,会因为初始值的选择以及选择训练样本的次序而改变。
核心思想
建立分离超平面,将正负实例点划分到超平面的两侧
模型
输入到输出的函数:
策略
损失函数最小化。很自然的想法就是误分类样本个数作为损失函数。但这对于参数w, b而言不易优化(不连续可导),故采用误分类样本点到分离超平面距离之和最小化为策略。没有误分类点时距离之和为0:
学习算法
随机梯度下降算法
算法流程
- Input: 训练数据集,学习率eta
- Output: 权值w, 阈值b
- Step1: 初始化权值w, b
- Step2: 随机梯度下降算法更新w, b,针对每个样本进行训练直到所有样本均分类正确
对偶形式感知机
模型
当初始化权值w为0向量。梯度下降更新权值 ,若每个样本学习了 次,那么:
将感知机中w的训练改为对 的训练即可训练对偶形式的感知机。
代码
"""
感知机(perceptron):原始形式以及对偶形式
"""
import numpy as np
class Perceptron:
def __init__(self, eta=1):
self.eta = eta # 学习率
self.w = None # 权值
self.b = None # 阈值
def fit(self, X_data, y_data):
self.w = np.zeros(X_data.shape[1]) # 初始化
self.b = 0
change = True
while change: # w, b 不发生改变则结束训练
for X, y in zip(X_data, y_data): # 依次输入每个数据点进行训练
change = False
while y * (self.w @ X + self.b) <= 0:
self.w += self.eta * X * y
self.b += self.eta * y
change = True
return
def predict(self, X):
return np.sign(self.w @ X + self.b)
class Perceptron_dual:
# 对偶形式的感知机
def __init__(self, eta=1):
self.eta = eta
self.alpha = None # alpha相当于样本的权值,当eta为1时就是每个样本参与训练的次数
self.b = None
self.N = None
self.gram = None
def init_param(self, X_data):
self.N = X_data.shape[0]
self.alpha = np.zeros(self.N)
self.b = 0
self.gram = self.getGram(X_data)
def getGram(self, X_data):
# 计算Gram矩阵
gram = np.diag(np.linalg.norm(X_data, axis=1) ** 2)
for i in range(self.N):
for j in range(i + 1, self.N):
gram[i, j] = X_data[i] @ X_data[j]
gram[j, i] = gram[i, j]
return gram
def sum_dual(self, y_data, i):
s = 0
for j in range(self.N):
s += self.alpha[j] * y_data[j] * self.gram[j][i]
return y_data[i] * (s + self.b)
def fit(self, X_data, y_data):
self.init_param(X_data)
changed = True
while changed:
changed = False
for i in range(self.N): # 依次输入每个数据点进行训练
while self.sum_dual(y_data, i) <= 0:
self.alpha[i] += self.eta
self.b += self.eta * y_data[i]
changed = True
return
if __name__ == '__main__':
X_data = np.array([[3, 3], [4, 3], [1, 1]])
y_data = np.array([1, 1, -1])
p = Perceptron()
p.fit(X_data, y_data)
print(p.w, p.b)