本文算法来源于李航的《统计学习方法》。
本文主要实现其中的两种算法,并做两种时间的对比。
不难看出这里面有固定不变的部分
是不变的可以用gram矩阵储存,减少后续计算部分,这个也是计算时间大幅度减少的主要原因。
public class Test1_1 {
public static double[]w = {0.0,0.0} ;//初始值
public static double b =0.0;//初始值
public static int number;//记录迭代次数
public static int N=10000;//记录数据多少
public static double arg = 1.0;
public static boolean flag = true;//记录是否需要继续迭代
public static double gram[][] = new double[N][N];//对偶形式的Gram矩阵
public static Data datas[] = new Data[N];//数据
public static double a[] = new double[N];//对偶形式中的ai
static class Data{
double x1 = 0.0;
double x2 = 0.0;
int y = 0;
}
public static void main(String[] args) throws InterruptedException {
Scanner scanner = new Scanner(System.in);
//模拟产生w=(1,1) b=-1;
product();//模拟数据
long startTime = System.currentTimeMillis();
FUN();//原始形式
long stopTime = System.currentTimeMillis();
System.out.println("原始形式:"+(stopTime-startTime)+"ms");
startTime = System.currentTimeMillis();
FUN2();//对偶形式
stopTime = System.currentTimeMillis();
System.out.println("对哦形式:"+(stopTime-startTime)+"ms");
}
private static void FUN2( ){
//计算Gram
for(int i = 0 ; i < N ;i++) {
for(int j = 0; j < N ;j++){
gram[i][j] = datas[i].x1*datas[j].x1+ datas[i].x2*datas[j].x2;
}
}
//进行学习
while(flag){
flag = false;
for (int i = 0; i < N; i++) {
if(misclassification(i)){
a[i]+=arg;
b = b+arg*datas[i].y;
flag = true;
number++;
break;
}
}
}
for(int i = 0 ; i < N ; i++) {
if (datas[i].y > 0) {
w[0] += (a[i] * datas[i].x1);
w[1] += (a[i] * datas[i].x2);
}else {
w[0] -= (a[i] * datas[i].x1);
w[1] -= (a[i] * datas[i].x2);
}
}
System.out.println("对偶形式迭代次数:"+number);
System.out.print("w:("+w[0]+","+w[1]+") ");
System.out.println(" b:"+b);
}
private static boolean misclassification(int id) {
double ans = b;
for(int i=0;i<N;i++) {
ans += (a[i] * datas[i].y * gram[i][id]);
}
System.out.println(ans * datas[id].y > 0);
if(ans * datas[id].y > 0)return false;
else return true;
}
//模拟产生w(1,1) b = -1
private static void product() {
for(int i = 0 ; i< N;i++){
datas[i] = new Data();
datas[i].x1 = Math.random()*2;
datas[i].x2 = Math.random()*2;
if((datas[i].x1*1+datas[i].x2*1-1)<=0)
{
datas[i].y = -1;
}else datas[i].y = 1;
}
}
private static void FUN() {
//开始感知机算法部分 其中学习率为 1
while (flag){
for (int i = 0; i < N; i++){
flag = false;
if(datas[i].y*(datas[i].x1*w[0]+datas[i].x2*w[1]+b)<=0){
flag = true;
w[0] = w[0] + 1 * datas[i].y * datas[i].x1;
w[1] = w[1] + 1 * datas[i].y * datas[i].x2;
b = b + datas[i].y;
number++;
break;
}
}
}
System.out.println("原始形式迭代次数:"+number);
System.out.print("w:("+w[0]+","+w[1]+") ");
System.out.println(" b:"+b);
}
}
从中看出,对偶形式时间明显优于原始形式,并且代码不算复杂