逻辑回归代价函数的 Java 实现

数据集

X

0,1,2,3,4,5,6,7,8,9,10

Y

0,0,0,0,0,1,1,1,1,1,1

从数据可以看出：当 X 大于 4 时 Y 等于 1，否则 Y 等于 0。

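逻辑回归代价函数计算公式：

$$J(\theta) = -\frac{1}{m}\sum_{i=1}^{m}\left[y^{(i)}\log h_\theta\left(x^{(i)}\right) + \left(1-y^{(i)}\right)\log\left(1-h_\theta\left(x^{(i)}\right)\right)\right] + \frac{\lambda}{2m}\sum_{j=1}^{n}\theta_j^2$$

其中 $h_\theta(x) = \frac{1}{1+e^{-\theta^{T}x}}$ 为 sigmoid 函数。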

公式右侧一项为正则化项。这里不加入正则化，因为数据中的规律已经足够明显。

代码

package ojama;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;

/**
 * Logistic-regression cost demo.
 *
 * <p>Loads a single feature column X and labels Y from comma-separated text files, fits theta
 * with {@code GradientDescent.getTheta}, prints theta, then prints the unregularized cost
 * J(theta) = -(1/m) * sum_i [ y_i * log(h(x_i)) + (1 - y_i) * log(1 - h(x_i)) ],
 * where h(x) = sigmoid(theta . x).
 */
public class CostFunction {
	/** Feature file read when no path is passed as args[0]. */
	private static final String DEFAULT_X_PATH = "C:/Users/BenQ/Desktop/X.txt";
	/** Label file read when no path is passed as args[1]. */
	private static final String DEFAULT_Y_PATH = "C:/Users/BenQ/Desktop/Y.txt";

	/**
	 * Fits theta on X/Y and prints each theta component (two decimals) followed by the cost.
	 *
	 * @param args optional: args[0] = path of the X file, args[1] = path of the Y file;
	 *             the original hard-coded desktop paths are used when omitted
	 * @throws IOException if either file cannot be read
	 * @throws IllegalArgumentException if X and Y contain a different number of values
	 */
	public static void main(String[] args) throws IOException {
		String xPath = args.length > 0 ? args[0] : DEFAULT_X_PATH;
		String yPath = args.length > 1 ? args[1] : DEFAULT_Y_PATH;
		Double[] x1 = read(xPath);
		Double[] y = read(yPath);
		if (x1.length != y.length) {
			throw new IllegalArgumentException(
					"X has " + x1.length + " values but Y has " + y.length);
		}
		int m = y.length;
		// Column of 1s so that theta[0] acts as the intercept (bias) term.
		Double[] x0 = new Double[m];
		for (int i = 0; i < m; i++) {
			x0[i] = 1.0;
		}
		List<Double[]> X = new ArrayList<>();
		X.add(x0);
		X.add(x1);
		Double[] theta = GradientDescent.getTheta(X, y);
		for (Double t : theta) {
			System.out.println(String.format("%.2f", t));
		}
		System.out.println(cost(X, y, theta));
	}

	/**
	 * Computes the unregularized logistic-regression cost (mean negative log-likelihood).
	 *
	 * @param X     feature columns; {@code X.get(k)[i]} is feature k of sample i
	 * @param y     labels, one per sample (expected to be 0 or 1; not validated)
	 * @param theta one weight per feature column
	 * @return J(theta) as a double
	 * @throws IllegalArgumentException if there are no samples or theta.length != X.size()
	 */
	public static double cost(List<Double[]> X, Double[] y, Double[] theta) {
		int m = y.length;
		if (m == 0) {
			throw new IllegalArgumentException("cost is undefined for zero samples");
		}
		if (theta.length != X.size()) {
			throw new IllegalArgumentException(
					"theta has " + theta.length + " entries but X has " + X.size() + " columns");
		}
		// Must be a double: an int accumulator truncates every per-sample term to 0.
		double total = 0.0;
		for (int i = 0; i < m; i++) {
			double z = 0.0;
			for (int k = 0; k < theta.length; k++) {
				// In 2-D this is k*x + b*1; in 3-D a*x + b*y + c*1; and so on.
				z += theta[k] * X.get(k)[i];
			}
			// -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))], rewritten with
			// log(sigmoid(z)) = -softplus(-z) and log(1 - sigmoid(z)) = -softplus(z).
			// This avoids log(0) and 0 * -Infinity = NaN once sigmoid rounds to exactly 0 or 1.
			total += y[i] * softplus(-z) + (1 - y[i]) * softplus(z);
		}
		return total / m;
	}

	/**
	 * Logistic sigmoid 1 / (1 + e^-z).
	 *
	 * @param z input value
	 * @return sigmoid(z), in [0, 1]
	 */
	public static double sigmoid(double z) {
		return 1 / (1 + Math.exp(-z));
	}

	/** Computes log(1 + e^t) without overflow or precision loss for large |t|. */
	private static double softplus(double t) {
		return Math.max(t, 0.0) + Math.log1p(Math.exp(-Math.abs(t)));
	}

	/**
	 * Reads comma-separated numbers from a text file.
	 *
	 * <p>Line breaks act as separators. Spaces are removed from each token, and empty tokens
	 * (for example from a trailing comma or a blank line) are skipped.
	 *
	 * @param fileName path of the file to read (decoded with the platform default charset)
	 * @return the parsed values in file order; empty if the file has no numbers
	 * @throws IOException if the file cannot be opened or read
	 * @throws NumberFormatException if a non-empty token is not a number
	 */
	public static Double[] read(String fileName) throws IOException {
		List<Double> values = new ArrayList<>();
		try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
			String line;
			while ((line = reader.readLine()) != null) {
				for (String token : line.split(",")) {
					String cleaned = token.replace(" ", "").trim();
					if (!cleaned.isEmpty()) {
						values.add(Double.parseDouble(cleaned));
					}
				}
			}
		}
		return values.toArray(new Double[0]);
	}
}

输出结果


猜你喜欢

转载自blog.csdn.net/jidong2622/article/details/79373880
今日推荐