机器学习——感知机python可视化实现

简介

参考李航老师出版的《统计学习方法》,用python实现感知机学习的算法

感知机算法

这里贴书中介绍的算法原始形式:
这里写图片描述
具体的推导和逻辑理解可以查看书籍或者网上的其它博客。

感知机算法代码实现

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

def PLA_train(dataSet,plot = False):
    numLines = dataSet.shape[0]
    numFeatures = dataSet.shape[1]
    #模型初始化
    w = np.ones((1, numFeatures-1))
    b = 0.1
    k = 1
    i = 0
    #用梯度下降方法,逐渐调整w和b的值
    while i<numLines:
        if dataSet[i][-1] * (np.sum(w * dataSet[i,0:-1],)+ b) <0:   #y[i](w*x[i]+b)<0
            w = w + k*dataSet[i][-1] * dataSet[i,0:-1]  #w = w + k*y[i]
            b = b + k*dataSet[i][-1]    # b = b + k*y[i]
            i =0
        else:
            i +=1

    return w, b

为了测试该算法,这里简单模拟生成数据进行测试。假设生成的数据都是线性可分的,那么只需要在坐标轴上随机生成大量的数据点,在用一条标准线进行分类。然后用这些分类的数据进行训练,查看训练出的模型与标准线的差距。

模拟生成已分类数据

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

def makePLAData(w,b, numlines):
    w = np.array(w)
    numFeatures = len(w)
    x = np.random.rand(numlines, numFeatures) * 20  #随机产生numlines个数据的数据集
    cls = np.sign(np.sum(w*x,axis=1)+b)    #用标准线 w*x+b=0进行分类
    dataSet = np.column_stack((x,cls))
    #至此样例数据已经生成

    #下面是存储标准分类线,以便显示观察
    x = np.linspace(0, 20, 500)      #创建分类线上的点,以点构线。
    y = -w[...,0] / w[...,1] * x - b / w[...,1]
    rows = np.column_stack((x.T, y.T, np.zeros((500, 1))))
    dataSet = np.row_stack((dataSet, rows))

    return dataSet

数据可视化函数

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

def showFigure(dataSet):
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    ax.set_title('Linear separable data set')
    plt.xlabel('X')
    plt.ylabel('Y')
    #图例设置
    labels = ['类一', '标准线', '类二', '模型线']
    markers = ['o','.','x','.']
    colors = ['r','y','g','b']
    for i in range(4):
        idx = np.where(dataSet[:,2]==i-1)   #找出同类型的点,返回索引值
        ax.scatter(dataSet[idx, 0], dataSet[idx, 1], marker=markers[i], color=colors[i], label=labels[i], s=10)

    plt.legend(loc = 'upper right')
    plt.show()

#测试
w = [1,-2]
b = 7
nums = 200
dataSet = makePLAData([1,-2],7,nums)
showFigure(dataSet)

测试运行效果如下:
这里写图片描述

完整代码

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt


def makePLAData(w,b, numlines):
    w = np.array(w)
    numFeatures = len(w)
    x = np.random.rand(numlines, numFeatures) * 20  #随机产生numlines个数据的数据集
    cls = np.sign(np.sum(w*x,axis=1)+b)    #用标准线 w*x+b=0进行分类
    dataSet = np.column_stack((x,cls))
    #至此样例数据已经生成

    #下面是存储标准分类线,以便显示观察
    x = np.linspace(0, 20, 500)      #创建分类线上的点,以点构线。
    y = -w[...,0] / w[...,1] * x - b / w[...,1]
    rows = np.column_stack((x.T, y.T, np.zeros((500, 1))))
    dataSet = np.row_stack((dataSet, rows))

    return dataSet


def showFigure(dataSet):
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    ax.set_title('Linear separable data set')
    plt.xlabel('X')
    plt.ylabel('Y')
    #图例设置
    labels = ['classOne', 'standarLine', 'classTow', 'modelLine']
    markers = ['o','.','x','.']
    colors = ['r','y','g','b']
    for i in range(4):
        idx = np.where(dataSet[:,2]==i-1)   #找出同类型的点,返回索引值
        ax.scatter(dataSet[idx, 0], dataSet[idx, 1], marker=markers[i], color=colors[i], label=labels[i], s=10)

    plt.legend(loc = 'upper right')
    plt.show()


def PLA_train(dataSet,plot = False):
    numLines = dataSet.shape[0]
    numFeatures = dataSet.shape[1]
    #模型初始化
    w = np.ones((1, numFeatures-1))
    b = 0.1
    k = 1
    i = 0
    #用梯度下降方法,逐渐调整w和b的值
    while i<numLines:
        if dataSet[i][-1] * (np.sum(w * dataSet[i,0:-1],)+ b) <0:
            w = w + k*dataSet[i][-1] * dataSet[i,0:-1]
            b = b + k*dataSet[i][-1]
            i =0
        else:
            i +=1

    #到这里参数w和b已经训练出来了,模型训练到此为止
    #下面是为了存储分类线,以便显示观察。
    x = np.linspace(0,20,500)    #创建分类线上的点,以点构线。
    y = -w[0][0]/w[0][1]*x - b/w[0][1]
    rows = np.column_stack((x.T,y.T,2*np.ones((500,1))))
    dataSet = np.row_stack((dataSet,rows))

    showFigure(dataSet)
    return w, b

#测试:
dataSet = makePLAData([1,-2],7,200)
showFigure(dataSet)
w,b= PLA_train(dataSet,True)

运行效果如下:
这里写图片描述
这里写图片描述

猜你喜欢

转载自blog.csdn.net/u014556057/article/details/81289915