多元假设函数
$$h_\theta(x) = \theta_0 + \theta_1 x_1 + \theta_2 x_2 + \cdots + \theta_n x_n$$
多元代价函数
$$J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right)^2$$
梯度下降
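Repeat{
$$\theta_j := \theta_j - \alpha\,\frac{1}{m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right)x_j^{(i)} \qquad (j = 0, 1, \dots, n,\ x_0^{(i)} = 1)$$
}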
学习率 α 太大可能导致不能收敛,太小则收敛很慢,需要经过很多次迭代才能到达最小值。
怎么判断收敛:一种是循环足够多的次数,反正越到最后移动越慢;第二种是如果后一次的值与前一次的值相比变化很小(比如小于 0.0001),就认为已经收敛。
多元函数是每次同时更新每个 θ_j(先用旧的 θ 算出所有偏导数,再一起更新)。
下面是实现代码
package com.zy.ml;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
/**
 * Multivariate linear regression fitted with batch gradient descent.
 *
 * <p>Training data is a comma-separated text file with one sample per line. The
 * first {@code theta.length - 1} columns are the features and the last column
 * is the target value. The model is
 * {@code h(x) = theta[0] + theta[1] * x1 + ... + theta[n] * xn}.
 *
 * <p>All configuration and state is kept in static fields, so every instance
 * shares the same model. The class is not thread-safe.
 *
 * @author yzhang
 */
public class LinearRegressionBGD {

    /**
     * Trains on the file given as the first command-line argument, falling back
     * to the original hard-coded sample path when no argument is supplied.
     *
     * @param args optional: path of the training file
     */
    public static void main(String[] args) {
        LinearRegressionBGD lr = new LinearRegressionBGD(0.0001, 100000, 0.0001);
        String path = args.length > 0 ? args[0] : "/home/zy/Desktop/test.txt";
        lr.getTrainDate(path);
        lr.getTheta();
    }

    /** Raw training lines, one comma-separated sample per entry. */
    static List<String> trainData = new ArrayList<String>();
    /** Learning rate. */
    static double alpha = 0.01;
    /** Current parameters; index 0 is the intercept. */
    static Double[] theta = new Double[] { 1.0, 1.0, 1.0 };
    /** Snapshot of the parameters taken at the start of the latest iteration. */
    static Double[] tmptheta = new Double[] { 1.0, 1.0, 1.0 };
    /** Number of training lines loaded by {@link #getTrainDate(String)}. */
    static int m = 0;
    /** Maximum number of gradient descent iterations. */
    static int time = 100000;
    /** Convergence threshold on the per-parameter change between iterations. */
    static double end = 0.0001;

    /** Creates a trainer that uses the current (default) settings. */
    public LinearRegressionBGD() {
    }

    /**
     * Creates a trainer and applies the given settings. Because the settings
     * are static, they apply to every instance of this class.
     *
     * @param alpha learning rate
     * @param time  maximum number of iterations
     * @param end   convergence threshold on the per-parameter change
     */
    public LinearRegressionBGD(double alpha, int time, double end) {
        // The parameters shadow the static fields, so qualify the targets.
        LinearRegressionBGD.alpha = alpha;
        LinearRegressionBGD.time = time;
        LinearRegressionBGD.end = end;
    }

    /**
     * Appends the non-blank lines of {@code file} to the training data and
     * refreshes the sample count {@code m}.
     *
     * @param file path of a UTF-8 text file with one comma-separated sample per line
     * @throws UncheckedIOException if the file cannot be opened or read; nothing
     *         is added to the training data in that case
     */
    public void getTrainDate(String file) {
        // Read into a local list first so a failure half-way leaves trainData untouched.
        List<String> loaded = new ArrayList<String>();
        try (BufferedReader reader =
                Files.newBufferedReader(Paths.get(file), StandardCharsets.UTF_8)) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.trim().isEmpty()) {
                    loaded.add(line);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Cannot read training data from " + file, e);
        }
        trainData.addAll(loaded);
        m = trainData.size();
    }

    /**
     * Parses every training line once into a numeric row laid out as
     * {@code [1.0, x1, ..., x(paramCount-1), y]}: a leading constant for the
     * intercept, the features, and the target taken from the last column.
     *
     * @param paramCount number of model parameters (intercept included)
     * @return one numeric row per training line
     * @throws IllegalArgumentException if a line has fewer than {@code paramCount} columns
     * @throws NumberFormatException    if a used column is not a valid number
     */
    private static double[][] parseSamples(int paramCount) {
        double[][] samples = new double[trainData.size()][];
        for (int i = 0; i < samples.length; i++) {
            String[] cols = trainData.get(i).split(",");
            if (cols.length < paramCount) {
                throw new IllegalArgumentException("Training line " + (i + 1) + " needs at least "
                        + paramCount + " columns but has " + cols.length + ": " + trainData.get(i));
            }
            double[] row = new double[paramCount + 1];
            row[0] = 1.0; // x0 = 1 so the intercept is handled like any other parameter
            for (int k = 1; k < paramCount; k++) {
                row[k] = Double.parseDouble(cols[k - 1]);
            }
            row[paramCount] = Double.parseDouble(cols[cols.length - 1]);
            samples[i] = row;
        }
        return samples;
    }

    /**
     * Computes every partial derivative of the cost function at {@code params}:
     * {@code dJ/dtheta_j = 1/m * sum((h(x_i) - y_i) * x_ij)}.
     *
     * @param samples rows as produced by {@link #parseSamples(int)}
     * @param params  parameter values at which the gradient is evaluated
     * @return the gradient, one entry per parameter
     */
    private static double[] computeGradient(double[][] samples, Double[] params) {
        int n = params.length;
        double[] gradient = new double[n];
        for (double[] row : samples) {
            double hypothesis = 0.0;
            for (int k = 0; k < n; k++) {
                hypothesis += params[k] * row[k];
            }
            double residual = hypothesis - row[n];
            for (int j = 0; j < n; j++) {
                gradient[j] += residual * row[j];
            }
        }
        for (int j = 0; j < n; j++) {
            gradient[j] /= samples.length;
        }
        return gradient;
    }

    /**
     * Runs batch gradient descent on the loaded training data, updating
     * {@code theta} in place and printing the parameters after every iteration.
     *
     * <p>Stops after {@code time} iterations, or earlier once every parameter
     * changed by less than {@code end} in a single iteration. The configured
     * {@code time} is left untouched so the method can be called again.
     *
     * @throws IllegalStateException if no training data has been loaded
     */
    public void getTheta() {
        if (trainData.isEmpty()) {
            throw new IllegalStateException("No training data loaded; call getTrainDate first");
        }
        // Parse once up front instead of re-splitting every line on every iteration.
        double[][] samples = parseSamples(theta.length);
        int iteration = 1;
        for (int remaining = time; remaining > 0; remaining--) {
            System.out.println("第" + (iteration++) + "次");
            tmptheta = theta.clone();
            // Simultaneous update: the whole gradient is evaluated at the
            // snapshot before any parameter is changed.
            double[] gradient = computeGradient(samples, tmptheta);
            for (int i = 0; i < theta.length; i++) {
                theta[i] -= alpha * gradient[i];
                System.out.println("theta" + i + ":" + theta[i]);
            }
            // Converged only when every parameter moved by less than the threshold.
            boolean converged = true;
            for (int i = 0; i < theta.length; i++) {
                if (!(Math.abs(theta[i] - tmptheta[i]) < end)) {
                    converged = false;
                    break;
                }
            }
            if (converged) {
                return;
            }
        }
    }
}