内容出自于李航的《统计学习方法》此处目的主要是实现其中算法
import java.util.Scanner;
public class Test1_1 {
public static double[]w = {0.0,0.0} ;//初始值
public static double b =0.0;//初始值
public static int number;//记录迭代次数
public static int N=1000;//记录数据多少
public static boolean flag = true;//记录是否需要继续迭代
static class Data{
double x1 = 0.0;
double x2 = 0.0;
int y = 0;
}
public static void main(String[] args) throws InterruptedException {
Scanner scanner = new Scanner(System.in);
//模拟产生w=(1,1) b=-1;
Data datas[] = new Data[N];
for(int i = 0 ; i< N;i++){
datas[i] = new Data();
datas[i].x1 = Math.random()*2;
datas[i].x2 = Math.random()*2;
if((datas[i].x1*1+datas[i].x2*1-1)<=0)
{
datas[i].y = -1;
}else datas[i].y = 1;
}
//开始感知机算法部分 其中学习率为 1
while (flag){
for (int i = 0; i < N; i++){
flag = false;
if(datas[i].y*(datas[i].x1*w[0]+datas[i].x2*w[1]+b)<=0){
flag = true;
w[0] = w[0] + 1 * datas[i].y * datas[i].x1;
w[1] = w[1] + 1 * datas[i].y * datas[i].x2;
b = b + datas[i].y;
System.out.println("第"+(number+1)+"此迭代:");
System.out.print("w:("+w[0]+","+w[1]+") ");
System.out.println(" b:"+b);
number++;
break;
}
}
}
System.out.println("总迭代次数:"+number);
System.out.print("w:("+w[0]+","+w[1]+") ");
System.out.println(" b:"+b);
}
}
结果如下
可以发现,当数据组数越大时,学习情况越好,
预想值w(1,1),b=(-1)
学习值w(4.047,4.035) b=-4.0;
学习情况很接近了