李航《统计学习方法》感知机代码

李航《统计学习方法》感知机代码

import numpy as np
import time

def loadData(filename):
    '''
    加载minist
    训练集数量:60000
    测试集数量:10000
    '''
    print('start to read data')
    #存放数据及标记的list
    dataArr = []
    labelArr= []
    #打开文件,只读形式打开
    fr = open(filename,'r')
    #将文件逐行读取
    for line in fr.readlines():
        #line.strip()删除换行符,line.split(',')以','分割
        curLine = line.strip().split(',')
        #Mnist有0-9个标记,由于是二分类任务,所以将>=5的作为1,<5的作为-1
        if int(curLine[0]) >= 5:
            labelArr.append(1)
        else:
            labelArr.append(-1)
        #存放标记
        dataArr.append([int(num)/255 for num in curLine[1:]])

    return dataArr,labelArr

def perceptron(dataArr,labelArr,iter = 50):
    '''
    感知机训练过程
    可以理解为将mnist数据集中28*28=784大小的像素值看做是不同维度的数值
    相当于在784维空间中寻找一个超平面,可以将>=5和<5进行很好的分类
    输入:dataArr训练集
                  labelArr训练集标签
                  iter迭代循环次数,默认为50
    输出:训练权重w和偏置b
    '''
    print('start to train')
    #将数据转换为矩阵形式
    dataMat = np.mat(dataArr)           #m行n列
    labelMat = np.mat(labelArr).T    #m行1列
    
    m,n = np.shape(dataMat)
    #w为1行n列
    w = np.zeros((1,n))
    b = 0
    h = 0.0001

    for k in range(iter):
        for i in range(m):
            xi = dataMat[i]  #1行n列
            yi = labelMat[i] #1行1列
            #"-yi * (w * xi.T + b)"就是感知机的导数
            if -1 * yi * (w * xi.T + b) >= 0:
                #w和b的更新
                w = w + h * yi * xi
                b = b + h * yi
        print('Round %d:%d training' %(k, iter))

    return w, b

def model_test(dataArr,labelArr,w,b):
    '''
    测试准确率
    输入:dataArr测试数据集
                 labelArr测试集标签
                 w 训练得到的权重w
                 b 训练得到的偏置b
    返回:准确率
    '''
    print('start to test')
    dataMat = np.mat(dataArr)    #m*n
    labelMat = np.mat(labelArr).T # n* 1

    m,n = np.shape(dataMat)
    #errorCnt记录错误样本数
    errorCnt = 0
    for i in range(m):
        xi = dataMat[i]
        yi = labelMat[i]
        result = -1 * yi * (w * xi.T + b)

        if result >= 0:
            errorCnt += 1
    accruRate = 1 - (errorCnt / m)
    return accruRate

if __name__ == '__main__':
    #获取当前时间
    start = time.time()
    #获取训练集及标签
    trainData,trainLabel = loadData('./mnist_train.csv')
    testData,testLabel = loadData('./mnist_test.csv')

    w,b = perceptron(trainData,trainLabel,30)
    accruRate = model_test(testData,testLabel,w,b)
    end = time.time()
    print('accuracy rate is:',accruRate)
    print('time span:',end - start)

猜你喜欢

转载自blog.csdn.net/m0_45388819/article/details/113751574