随机梯度下降和批量梯度下降的简单代码实现

        最近刚刚开始看斯坦福的机器学习公开课,第一堂课讲到了梯度下降,然后就简单实现了一下。本人学渣一枚,如有错误,敬请指出。

     

/**
 * Simple demonstrations of stochastic and batch gradient descent for
 * fitting a one-variable linear regression y = p[1] * x + p[0].
 *
 * Created by Administrator on 2016/4/16 0016.
 */
public class GradientDescent {
    // Sample {x, y} pairs taken from a day-to-day project.
    private static double[][] data = {
            {3.8, 192.0314202},
            {3.5, 194.1168421},
            {4, 195.1114837},
            {4.4, 197.7640977},
            {4.1, 196.8811122},
            {4.6, 202.9643527},
            {3.6, 191.245283},
            {3.2, 189.2631579},
            {3.4, 189.9758454},
            {3, 187.6717949},
            {3.9, 193.5243902},
            {3.1, 189.2704403},
            {2.2, 177.248366},
            {3.7, 189.296875},
            {3.3, 189.5043478},
            {4.2, 199.6857143},

    };
    // Regression line Excel computes for the same data: y = 9.3581x + 158.3

    public static void main(String[] args) {
        stochastic(data);

        batch(data);
    }

    /*
     * With rate = 0.01 the parameters stop changing after ~2000 iterations:
     *   parameter is 157.90981024717982 9.482991891267803
     *   error is 47.897064097242335
     *
     * With rate = 0.001 and 30000 iterations the result is nearly stable:
     *   parameter is 158.25947581125462 9.36980772136795
     *   error is 47.7491293901012
     */

    /**
     * Stochastic gradient descent: the parameters are updated after every
     * individual sample, so each step is cheap but the path is noisier.
     *
     * @param data training set of {x, y} pairs
     */
    private static void stochastic(double[][] data) {
        double[] p = {0, 0}; // p[0] = intercept, p[1] = slope; start at zero
        double rate = 0.001; // learning rate

        for (int i = 0; i < 30000; i++) {
            for (double[] sample : data) {
                double h = p[0] + p[1] * sample[0]; // current prediction
                double err = sample[1] - h;         // residual for this sample

                // Update the parameters from this single sample.
                p[0] += rate * err;             // gradient w.r.t. intercept is 1
                p[1] += rate * err * sample[0]; // gradient w.r.t. slope is x
            }
        }
        System.out.println("parameter is " + p[0] + " " + p[1]);

        System.out.println("error is " + sumSquaredError(p, data));
    }


    /*
     * With rate = 0.001 and 50000 iterations the result is almost identical
     * to the Excel regression:
     *   parameter is 158.299201608832 9.358074090590318
     *   error is 47.74825830393555
     *
     * Batch gradient descent is indeed more accurate.
     */

    /**
     * Batch gradient descent: gradients are accumulated over the entire
     * data set, and the parameters are updated once per full pass.
     *
     * @param data training set of {x, y} pairs
     */
    private static void batch(double[][] data) {
        double[] p = {0, 0}; // p[0] = intercept, p[1] = slope; start at zero
        double rate = 0.001; // learning rate

        for (int i = 0; i < 50000; i++) {
            double err1 = 0; // accumulated gradient for the intercept
            double err2 = 0; // accumulated gradient for the slope

            for (double[] sample : data) {
                double h = p[0] + p[1] * sample[0]; // current prediction
                err1 += sample[1] - h;
                err2 += (sample[1] - h) * sample[0];
            }

            // Update the parameters only after a full pass over the data.
            p[0] += rate * err1;
            p[1] += rate * err2;
        }

        System.out.println("parameter is " + p[0] + " " + p[1]);

        System.out.println("error is " + sumSquaredError(p, data));
    }

    /**
     * Sum of squared residuals of the line p[0] + p[1] * x over the data.
     * Shared by both descent variants (was previously duplicated inline);
     * package-private so it can be exercised directly by tests.
     *
     * @param p    fitted parameters: {intercept, slope}
     * @param data set of {x, y} pairs
     * @return sum over all pairs of (y - (p[0] + p[1] * x))^2
     */
    static double sumSquaredError(double[] p, double[][] data) {
        double error = 0;
        for (double[] sample : data) {
            double residual = sample[1] - (p[0] + p[1] * sample[0]);
            error += residual * residual;
        }
        return error;
    }

}


猜你喜欢

转载自blog.csdn.net/AlexZhang67/article/details/51170347
今日推荐