感知机的原始形式和对偶形式,并说明Gram矩阵的用处及对比两种形式的时间

本文算法来源于李航的《统计学习方法》。
本文主要实现其中的两种算法,并做两种时间的对比。
在这里插入图片描述
不难看出这里面有固定不变的部分
在这里插入图片描述
是不变的可以用gram矩阵储存,减少后续计算部分,这个也是计算时间大幅度减少的主要原因。

public class Test1_1 {
    public static double[]w = {0.0,0.0} ;//初始值
    public static double b =0.0;//初始值
    public static int number;//记录迭代次数
    public static int N=10000;//记录数据多少
    public static double arg = 1.0;
    public static boolean flag = true;//记录是否需要继续迭代
    public static double gram[][] = new double[N][N];//对偶形式的Gram矩阵
    public static Data datas[] = new Data[N];//数据
    public static double a[] = new double[N];//对偶形式中的ai
    static class Data{
        double x1 = 0.0;
        double x2 = 0.0;
        int y = 0;
    }
    public static void main(String[] args) throws InterruptedException {
        Scanner scanner = new Scanner(System.in);
        //模拟产生w=(1,1) b=-1;
     product();//模拟数据
        long startTime =  System.currentTimeMillis();
       FUN();//原始形式
        long stopTime =  System.currentTimeMillis();
        System.out.println("原始形式:"+(stopTime-startTime)+"ms");
        startTime =  System.currentTimeMillis();
       FUN2();//对偶形式
        stopTime =  System.currentTimeMillis();
        System.out.println("对哦形式:"+(stopTime-startTime)+"ms");
    }

    private static void FUN2( ){
        //计算Gram
        for(int i = 0 ; i < N ;i++) {
            for(int j = 0; j < N ;j++){
                gram[i][j] = datas[i].x1*datas[j].x1+ datas[i].x2*datas[j].x2;
            }
        }
        //进行学习
        while(flag){
            flag = false;
            for (int i = 0; i < N; i++) {
                if(misclassification(i)){
                    a[i]+=arg;
                    b = b+arg*datas[i].y;

                    flag = true;
                    number++;
                    break;
                }
            }
        }
        for(int i = 0 ; i < N ; i++) {
            if (datas[i].y > 0) {
                w[0] += (a[i] * datas[i].x1);
                w[1] += (a[i] * datas[i].x2);
            }else {
                w[0] -= (a[i] * datas[i].x1);
                w[1] -= (a[i] * datas[i].x2);
            }
        }
        System.out.println("对偶形式迭代次数:"+number);
        System.out.print("w:("+w[0]+","+w[1]+") ");
        System.out.println("   b:"+b);
    }

    private static boolean misclassification(int id) {
        double ans = b;
        for(int i=0;i<N;i++) {
            ans += (a[i] * datas[i].y * gram[i][id]);
        }
        System.out.println(ans * datas[id].y > 0);
        if(ans * datas[id].y > 0)return false;
        else return  true;
    }

   //模拟产生w(1,1) b = -1
    private static void product() {
        for(int i = 0 ; i< N;i++){
            datas[i] = new Data();
            datas[i].x1 = Math.random()*2;
            datas[i].x2 = Math.random()*2;
            if((datas[i].x1*1+datas[i].x2*1-1)<=0)
            {
                datas[i].y = -1;
            }else datas[i].y = 1;
        }
    }

    private static void FUN() {
        //开始感知机算法部分 其中学习率为 1
        while (flag){
            for (int i = 0; i < N; i++){
                flag = false;
                if(datas[i].y*(datas[i].x1*w[0]+datas[i].x2*w[1]+b)<=0){
                    flag = true;
                    w[0] = w[0] + 1 * datas[i].y * datas[i].x1;
                    w[1] = w[1] + 1 * datas[i].y * datas[i].x2;
                    b = b + datas[i].y;
                    number++;
                    break;
                }

            }
        }
        System.out.println("原始形式迭代次数:"+number);
        System.out.print("w:("+w[0]+","+w[1]+") ");
        System.out.println("   b:"+b);
    }
}

在这里插入图片描述
从中看出,对偶形式时间明显优于原始形式,并且代码不算复杂

发布了30 篇原创文章 · 获赞 62 · 访问量 3088

猜你喜欢

转载自blog.csdn.net/weixin_43981664/article/details/90244498