感知机(perceptron) 学习笔记

前言:偶尔回想自己学过的算法,反复看反复忘,故又重复看一遍并记下笔记,供后续学习参考。


感知机是一个二分类算法,是深度学习的简化版,只有一层网络,建模思想跟支持向量机类似,是两算法的基础。

分类原理 :y =(编辑公式还有更好的办法吗?), 满足wx+b>=0的输入被分类为标签1,否则被分类为标签-1。

建模:分类的是超平面wx+b=0,以输入点到超平面的距离为判断标准。距离公式为d = |wx+b|/ ||w|| = y(wx+b)/||w||。||.||是二范数。

损失函数:损失的自然选择是错误分类的总点数,但是这样的损失不是参数w,b的连续可到函数,不易优化,所以将错误分类的总点数到超平面的距离的总距离定义为损失函数,因此损失函数deltaL =  -1*y(wx+b)/||w||,此处根据SVM中几何间隔,知需要对w进行约束,避免b与w同比例增加,参数变了但是超平面本身并没有变动,所以约定||w||=1,损失函数为deltaL =  -1*y(wx+b),更细致的推理可参见其他资料,(我用||w||计算损失进行训练没见特别冥想的错误?)。

优化算法:优化算法是为了求得最优解,这里使用随机梯度下降算法SGD,已知损失函数,根据其对变量w,b求导得到,dw = -yx, db = -y,设学习率为u = 0.5。w= w-dw = w+uyx, b = b - db = b+uy

代码如下:

import math
import random
'''
感知机是二分类模型
判断条件 y = 1 when w1x1+w2x2+b>0
        y = -1 when w1x1+w2x2+b<0

'''

#z准备数据集
X = [[3,2], [12,10], [33,62], [8,16], [23,45], [7,13], [78,65], [35,54], [77,55], [89,23]]
Y = [-1, -1, 1, -1, -1, -1,1, -1, 1, 1]

#训练网络/推理逻辑
# y= x1*w1 + x2*w2 +b

w1 = 0.1
w2 = 0.1
b = 0
u = 0.5
# for [x1,x2] in X:
#     y = x1*w1 + x2*w2 +b


#计算损失函数
#点到平面wx+b=0 的距离 d = |wx+b| / ||w||,对于误分类的则需要被统计作为损失,目标是使损失变为0
#d = -yi*(wxi+b)/math.sqrt(sum(pow(w,2)))

#优化算法 梯度函数
#dw = yixi ; db = yi, u是步长(学习率)
#w = w+ udw
#b = b+ udb

#每训练完一个epoch之后更新一次变量
deltal = 0
for i in range(10000):
    print(i)
    deltal = 0
    for i in range(20):

        index = random.randint(0,len(X)-1)
        # print((X[index][0]*w1 + X[index][1]*w2 +b)* Y[index])
        # print((X[index][0]*w1 + X[index][1]*w2 +b)* Y[index]*(-1))
        if((X[index][0]*w1 + X[index][1]*w2 +b)* Y[index]<=0):
            print("error")
            w1 = w1+ u*Y[index]*X[index][0]
            w2 = w2 + u*Y[index]*X[index][1]
            b = b + u*Y[index]
            # print(math.sqrt(pow(w1,2)+ pow(w2,2)))
            deltal =deltal + (X[index][0]*w1 + X[index][1]*w2 +b)* Y[index]*(-1)/math.sqrt(pow(w1,2)+ pow(w2,2))
            # print(deltal)
    print('损失函数是:{}, w1 is:{}, w2 is:{}, b is:{}'.format(deltal, w1, w2, b))
    if(deltal==0):
        break

  

     得到的结果如下:

0
error
error
error
error
error
error
error
损失函数是:-58.3676307788504, w1 is:9.600000000000001, w2 is:-11.899999999999999, b is:-2.5
1
error
error
error
error
error
error
error
损失函数是:-79.6906079022878, w1 is:16.1, w2 is:-27.4, b is:-5.0
2
error
error
error
error
error
error
error
error
error
error
error
损失函数是:-181.72511941015193, w1 is:19.1, w2 is:-35.4, b is:-7.5
3
error
error
error
error
error
error
error
error
error
损失函数是:-134.1050074696168, w1 is:52.1, w2 is:-21.4, b is:-9.0
4
error
error
error
error
error
error
error
error
error
error
error
损失函数是:-81.79882696995062, w1 is:21.1, w2 is:-49.9, b is:-12.5
5
error
error
error
error
error
error
error
error
error
error
error
损失函数是:-119.81313066808934, w1 is:23.6, w2 is:-19.9, b is:-14.0
6
error
error
error
error
损失函数是:-53.31304070579386, w1 is:21.1, w2 is:-17.4, b is:-15.0
7
error
error
error
error
error
损失函数是:-41.93949987798639, w1 is:31.6, w2 is:9.600000000000001, b is:-16.5
8
error
error
error
损失函数是:-5.054308599796722, w1 is:17.1, w2 is:-14.899999999999999, b is:-18.0
9
error
error
损失函数是:-5.061614611137055, w1 is:9.600000000000001, w2 is:-20.9, b is:-19.0
10
error
error
error
error
error
error
error
error
error
error
error
损失函数是:-176.0012076253418, w1 is:30.1, w2 is:6.100000000000001, b is:-21.5
11
error
error
error
error
error
error
error
error
error
error
损失函数是:-125.48127046860397, w1 is:37.1, w2 is:-13.899999999999999, b is:-24.5
12
error
error
error
error
error
error
error
error
error
error
error
损失函数是:-125.1563990295305, w1 is:7.100000000000001, w2 is:-41.9, b is:-28.0
13
error
error
error
error
error
error
error
error
损失函数是:-99.9588692981525, w1 is:31.6, w2 is:-34.4, b is:-30.0
14
error
error
error
error
error
error
error
error
损失函数是:-123.83992037821682, w1 is:39.1, w2 is:-43.4, b is:-32.0
15
error
error
error
error
error
error
error
error
error
error
损失函数是:-107.52054118479901, w1 is:38.099999999999994, w2 is:-61.4, b is:-35.0
16
error
error
error
error
error
error
损失函数是:-12.707209628891478, w1 is:53.599999999999994, w2 is:-18.4, b is:-36.0
17
error
error
error
error
error
error
error
损失函数是:-23.948721417884677, w1 is:40.099999999999994, w2 is:-28.4, b is:-38.5
18
error
error
error
error
error
error
error
error
损失函数是:-139.31131437239534, w1 is:31.099999999999994, w2 is:-56.4, b is:-40.5
19
error
error
error
error
error
error
error
损失函数是:-95.98994407835531, w1 is:48.099999999999994, w2 is:-62.9, b is:-42.0
20
error
error
error
error
error
error
error
error
error
损失函数是:-92.55322975282904, w1 is:78.6, w2 is:-39.9, b is:-43.5
21
error
error
损失函数是:-17.88492872798486, w1 is:55.099999999999994, w2 is:-71.9, b is:-44.5
22
error
error
损失函数是:-18.30431876055715, w1 is:54.099999999999994, w2 is:-67.9, b is:-44.5
23
error
error
损失函数是:-19.673217617324354, w1 is:53.099999999999994, w2 is:-63.900000000000006, b is:-44.5
24
error
error
损失函数是:-21.10211583473629, w1 is:52.099999999999994, w2 is:-59.900000000000006, b is:-44.5
25
error
error
error
error
error
error
error
error
损失函数是:-65.34583910969604, w1 is:66.6, w2 is:-47.400000000000006, b is:-46.5
26
error
error
error
error
error
error
error
error
error
损失函数是:-74.24796184936555, w1 is:81.6, w2 is:-42.400000000000006, b is:-49.0
27
error
error
error
损失函数是:-34.124857156842516, w1 is:63.099999999999994, w2 is:-65.4, b is:-49.5
28
error
error
error
error
error
error
error
error
error
error
损失函数是:-20.592229958037294, w1 is:52.099999999999994, w2 is:-56.400000000000006, b is:-52.5
29
error
error
error
error
error
error
error
error
损失函数是:-15.858697189511131, w1 is:58.099999999999994, w2 is:-24.900000000000006, b is:-54.5
30
error
error
error
error
error
损失函数是:-59.07572886107646, w1 is:52.099999999999994, w2 is:-24.900000000000006, b is:-55.0
31
error
error
error
error
error
损失函数是:-33.3306474537016, w1 is:35.099999999999994, w2 is:-38.900000000000006, b is:-56.5
32
损失函数是:0, w1 is:35.099999999999994, w2 is:-38.900000000000006, b is:-56.5

Process finished with exit code 0

重复运行多次,会有不同的结果。感知机算法由于采用不同的初值和选取不同的误分类点,解可以不同

以上为个人理解,如有不对的地方,欢迎交流指正~

猜你喜欢

转载自www.cnblogs.com/xiaoheizi-12345/p/13191641.html