线性回归与逻辑回归的 Java 实现

  线性回归和逻辑回归的实现大体一致，因此把公共流程抽象为抽象类 Regression。其中声明了三个抽象方法（PreVal 预测、CostFun 代价函数、Update 梯度下降更新），分别在线性回归和逻辑回归中重写。

  样本用 Sample 类表示，特征以 double 数组的形式存储。

1. 样本类Sample

 /**
 * One sample: a fixed-length feature vector plus a regression target ({@code value})
 * and a class label ({@code label}).
 *
 * <p>Fields are package-private because the regression classes read them directly.
 */
public class Sample {

    double[] features; // feature values, length feaNum
    int feaNum; // the number of sample's features
    double value; // target value of the sample, used by linear regression
    int label; // class of the sample, used by logistic regression

    /**
     * Creates a sample with {@code number} features, all initialised to 0.
     *
     * @param number the number of features
     * @throws NegativeArraySizeException if {@code number} is negative
     */
    public Sample(int number) {
        feaNum = number;
        features = new double[feaNum];
    }

    /** Prints the features, the label and the value of this sample to standard output. */
    public void outSample() {
        System.out.println("The sample's features are:");
        for (int i = 0; i < feaNum; i++) {
            System.out.print(features[i] + " ");
        }
        System.out.println();
        System.out.println("The label is: " + label);
        System.out.println("The value is: " + value);
    }
}

2. 抽象类Regression

/**
 * Common skeleton shared by linear and logistic regression.
 *
 * <p>Holds the training set and the hyper-parameters. Subclasses supply the hypothesis
 * ({@link #PreVal}), the cost function ({@link #CostFun}) and the optimisation loop
 * ({@link #Update}).
 */
public abstract class Regression {

    double[] theta; // model parameters, one per feature
    int paraNum; // the number of parameters (theta.length)
    double rate; // learning rate
    Sample[] sam; // training samples
    int samNum; // the number of training samples
    double th; // stop threshold on the per-iteration decrease of the cost

    /**
     * Stores the first {@code num} samples of {@code s} as the training set.
     *
     * <p>Only the array is copied; the {@link Sample} objects themselves are shared with the caller.
     *
     * @param s training set
     * @param num the number of training samples to take from the front of {@code s}
     * @throws NegativeArraySizeException if {@code num} is negative
     * @throws IndexOutOfBoundsException if {@code num} is larger than {@code s.length}
     */
    public void Initialize(Sample[] s, int num) {
        samNum = num;
        sam = new Sample[samNum];
        System.arraycopy(s, 0, sam, 0, samNum);
    }

    /**
     * Sets the initial parameters and the hyper-parameters.
     *
     * @param para initial theta; copied, so training never writes back into the caller's array
     * @param learning_rate gradient-descent step size
     * @param threshold training stops once an iteration lowers the cost by no more than this
     */
    public void setPara(double[] para, double learning_rate, double threshold) {
        paraNum = para.length;
        theta = para.clone();
        rate = learning_rate;
        th = threshold;
    }

    /**
     * Predicts the value of sample s.
     * @param s : prediction sample
     * @return : predicted value
     */
    public abstract double PreVal(Sample s);

    /**
     * Calculates the cost over all training samples.
     * @return : the cost
     */
    public abstract double CostFun();

    /**
     * Updates theta by gradient descent until the stop threshold is reached.
     */
    public abstract void Update();

    /**
     * Prints the current parameters on one line, followed by the current cost on its own line.
     */
    public void OutputTheta() {
        System.out.println("The parameters are:");
        for (int i = 0; i < paraNum; i++) {
            System.out.print(theta[i] + " ");
        }
        // Without this line break the cost was appended directly after the last parameter
        // and could not be told apart from it.
        System.out.println();
        System.out.println("The cost is: " + CostFun());
    }
}

3. 线性回归LinearRegression

public class LinearRegression extends Regression{

    public double PreVal(Sample s) {
        double val = 0;
        for(int i = 0; i < paraNum; i++) {
            val += theta[i] * s.features[i];
        }
        return val;
    }
    
    public double CostFun() {
        double sum = 0;
        for(int i = 0; i < samNum; i++) {
            double d = PreVal(sam[i]) - sam[i].value;
            sum += Math.pow(d, 2);
        }
        return sum / (2*samNum);
    }
    
    public void Update() {
         double former = 0; // the cost before update
         double latter = CostFun(); // the cost after update
         double d = 0;
         double[] p = new double[paraNum];
         do {
             former = latter;
             //update theta
             for(int i = 0; i < paraNum; i++) {
                 // for theta[i]
                 for(int j = 0; j < samNum; j++) {
                     d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i];
                 }
                 p[i] -= (rate * d) / samNum;
             }
             theta = p;
             latter = CostFun();
         }while(former - latter > th);
    }

}

4. 逻辑回归LogisticRegression

public class LogisticRegression extends Regression{

    public double PreVal(Sample s) {
        double val = 0;
        for(int i = 0; i < paraNum; i++) {
            val += theta[i] * s.features[i];
        }
        return 1/(1 + Math.pow(Math.E, -val));
    }

    public double CostFun() {
        double sum = 0;
        for(int i = 0; i < samNum; i++) {
            double p = PreVal(sam[i]);
            double d = Math.log(p) * sam[i].label + (1 - sam[i].label) * Math.log(1 - p);
            sum += d;
        }
        return -1 * (sum / samNum);
    }
    
    public void Update() {
         double former = 0; // the cost before update
         double latter = CostFun(); // the cost after update
         double d = 0;
         double[] p = new double[paraNum];
         do {
             former = latter;
             //update theta
             for(int i = 0; i < paraNum; i++) {
                 // for theta[i]
                 for(int j = 0; j < samNum; j++) {
                     d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i];
                 }
                 p[i] -= (rate * d) / samNum;
             }
             latter = CostFun();
         }while(former - latter > th);
         theta = p;
    }
}

5. 线性回归使用的样本（每行依次为 x0 x1 x2 x3 x4 y，x0 恒为 1；下面的表头仅作说明，测试代码会把每一行都解析为数字，因此数据文件中不应包含表头）

x0 x1 x2 x3 x4 y
1 2104 5 1 45 460
1 1416 3 2 40 232
1 1534 3 2 30 315
1 852 2 1 36 178
1 1254 3 3 45 321
1 987 2 2 35 241
1 1054 3 2 30 287
1 645 2 3 25 87
1 542 2 1 30 94
1 1065 3 1 25 241
1 2465 7 2 50 687
1 2410 6 1 45 654
1 1987 4 2 45 436
1 457 2 3 35 65
1 587 2 2 25 54
1 468 2 1 40 87
1 1354 3 1 35 215
1 1587 4 1 45 345
1 1789 4 2 35 325
1 2500 8 2 40 720

6. 线性回归测试

import java.io.IOException;
import java.io.RandomAccessFile;

public class Test {

    /**
     * Trains a {@link LinearRegression} on {@code resource//LinearSample.txt} and prints the
     * fitted parameters and the final cost.
     *
     * @param args unused
     * @throws IOException if the sample file cannot be read
     */
    public static void main(String[] args) throws IOException {
        Sample[] sam = readSamples("resource//LinearSample.txt");

        LinearRegression lr = new LinearRegression();
        double[] para = {0, 0, 0, 0, 0};
        // NOTE(review): x1 (about 450-2500) and x4 (about 25-50) are not feature-scaled. With a
        // learning rate of 0.5, batch gradient descent on this data is expected to diverge on the
        // first step. Consider mean-normalising x1..x4 or using a much smaller rate.
        double rate = 0.5;
        double th = 0.001;
        lr.Initialize(sam, sam.length);
        lr.setPara(para, rate, th);
        lr.Update();
        lr.OutputTheta();
    }

    /**
     * Reads one sample per non-blank line: whitespace-separated feature values followed by the
     * target value.
     *
     * @param path file to read
     * @return the samples, in file order, in an array of exactly the right length
     * @throws IOException if the file cannot be read
     * @throws NumberFormatException if a line contains a non-numeric token
     */
    private static Sample[] readSamples(String path) throws IOException {
        Sample[] sam = new Sample[25];
        int w = 0;
        // try-with-resources guarantees the file is closed.
        try (RandomAccessFile file = new RandomAccessFile(path, "r")) {
            String s;
            while ((s = file.readLine()) != null) {
                s = s.trim();
                if (s.isEmpty()) {
                    continue; // tolerate blank and trailing lines
                }
                String[] sub = s.split("\\s+");
                Sample sample = new Sample(sub.length - 1);
                for (int i = 0; i < sub.length - 1; i++) {
                    sample.features[i] = Double.parseDouble(sub[i]);
                }
                sample.value = Double.parseDouble(sub[sub.length - 1]);
                if (w == sam.length) { // grow instead of overflowing a fixed-size array
                    Sample[] bigger = new Sample[sam.length * 2];
                    System.arraycopy(sam, 0, bigger, 0, w);
                    sam = bigger;
                }
                sam[w++] = sample;
            }
        }
        Sample[] exact = new Sample[w];
        System.arraycopy(sam, 0, exact, 0, w);
        return exact;
    }
}

7. 逻辑回归使用的样本（每行依次为 x0 x1 x2 class，x0 恒为 1，class 取 0 或 1；下面的表头仅作说明，数据文件中不应包含表头）

x0 x1 x2 class
1 0.23 0.35 0
1 0.32 0.24 0
1 0.6 0.12 0
1 0.36 0.54 0
1 0.02 0.89 0
1 0.36 -0.12 0
1 -0.45 0.62 0
1 0.56 0.42 0
1 0.4 0.56 0
1 0.46 0.51 0
1 1.2 0.32 1
1 0.6 0.9 1
1 0.32 0.98 1
1 0.2 1.3 1
1 0.15 1.36 1
1 0.54 0.98 1
1 1.36 1.05 1
1 0.22 1.65 1
1 1.65 1.54 1
1 0.25 1.68 1

8. 逻辑回归测试

import java.io.IOException;
import java.io.RandomAccessFile;

public class Test {

    /**
     * Trains a {@link LogisticRegression} on {@code resource//LogisticSample.txt} and prints
     * the fitted parameters and the final cost.
     *
     * @param args unused
     * @throws IOException if the sample file cannot be read
     */
    public static void main(String[] args) throws IOException {
        Sample[] sam = readSamples("resource//LogisticSample.txt");

        LogisticRegression lr = new LogisticRegression();
        double[] para = {0, 0, 0};
        double rate = 0.5;
        double th = 0.001;
        lr.Initialize(sam, sam.length);
        lr.setPara(para, rate, th);
        lr.Update();
        lr.OutputTheta();
    }

    /**
     * Reads one sample per non-blank line: whitespace-separated feature values followed by an
     * integer class label.
     *
     * @param path file to read
     * @return the samples, in file order, in an array of exactly the right length
     * @throws IOException if the file cannot be read
     * @throws NumberFormatException if a line contains a malformed token
     */
    private static Sample[] readSamples(String path) throws IOException {
        Sample[] sam = new Sample[25];
        int w = 0;
        // try-with-resources guarantees the file is closed.
        try (RandomAccessFile file = new RandomAccessFile(path, "r")) {
            String s;
            while ((s = file.readLine()) != null) {
                s = s.trim();
                if (s.isEmpty()) {
                    continue; // tolerate blank and trailing lines
                }
                String[] sub = s.split("\\s+");
                Sample sample = new Sample(sub.length - 1);
                for (int i = 0; i < sub.length - 1; i++) {
                    sample.features[i] = Double.parseDouble(sub[i]);
                }
                sample.label = Integer.parseInt(sub[sub.length - 1]);
                if (w == sam.length) { // grow instead of overflowing a fixed-size array
                    Sample[] bigger = new Sample[sam.length * 2];
                    System.arraycopy(sam, 0, bigger, 0, w);
                    sam = bigger;
                }
                sam[w++] = sample;
            }
        }
        Sample[] exact = new Sample[w];
        System.arraycopy(sam, 0, exact, 0, w);
        return exact;
    }
}

猜你喜欢

转载自www.cnblogs.com/datamining-bio/p/9240378.html
今日推荐