【机器学习】Logistic Regression逻辑回归原理与java实现
1、基于概率的机器学习算法
机器学习算法可以分为基于概率、基于距离、基于树和基于神经网络四类。基于概率的机器学习算法本质上是计算每个样本属于对应类别的概率,然后利用极大似然估计法对模型进行训练。基于概率的机器学习算法的损失函数为负的log似然函数。
基于概率的机器学习算法包括朴素贝叶斯算法、Logistic Regression算法、Softmax Regression算法和Factorization Machine算法等。
2、逻辑回归算法原理
2.1、分离超平面
Logistic Regression算法是二分类线性分类算法,分离超平面采用线性函数:
是样本特征矩阵,特征数为
,其中
是权重矩阵。通过分类超平面可以将数据分成正负两个类别,类别为正的样本标签标记为1,类别为负的样本标签标记为0。
2.2、阈值函数
通过阈值函数可以将样本到分离超平面的距离映射到不同的类别,Logistic Regression算法中阈值函数采用Sigmoid函数:
sigmoid函数的图像如下:
对于样本
,其到分离超平面的几何距离
为:
2.3、样本概率
假设样本
为正类别,则其概率为:
负类别样本的概率:
将两种类别合并,属于类别
的概率为:
2.4、损失函数
设训练数据集有 个训练样本 ,其似然函数为:
Logistic Regression算法的损失函数为负的log似然函数:
模型训练是为了求取最优的权值矩阵
和偏置
,将模型训练问题转化为最小化损失函数:
3、基于梯度下降法的模型训练
本博文中,最小化损失函数的求解采用梯度下降法。
第一步:初始化权重矩阵
和偏置
。
第二步:重复如下过程:
计算参数的梯度下降方向:
选择步长
更新参数:
第三步:判断是否达到终止条件。
假设
是样本
的第
个特征分量,
为权重矩阵
的第
个分量,取
,则权重矩阵中第
个分量的梯度方向为:
4、java实现
完整代码和数据样本地址:https://github.com/shiluqiang/Logistic_Regression_java
首先:导入样本特征和标签。
import java.io.*;
public class LoadData {
//导入样本特征
public static double[][] Loadfeature(String filename) throws IOException{
File f = new File(filename);
FileInputStream fip = new FileInputStream(f);
// 构建FileInputStream对象
InputStreamReader reader = new InputStreamReader(fip,"UTF-8");
// 构建InputStreamReader对象
StringBuffer sb = new StringBuffer();
while(reader.ready()) {
sb.append((char) reader.read());
}
reader.close();
fip.close();
//将读入的数据流转换为字符串
String sb1 = sb.toString();
//按行将字符串分割,计算二维数组行数
String [] a = sb1.split("\n");
int n = a.length;
System.out.println("二维数组行数为:" + n);
//计算二维数组列数
String [] a0 = a[0].split("\t");
int m = a0.length;
System.out.println("二维数组列数为:" + m);
double [][] feature = new double[n][m];
for (int i = 0; i < n; i ++) {
String [] tmp = a[i].split("\t");
for(int j = 0; j < m; j ++) {
if (j == m-1) {
feature[i][j] = (double) 1;
}
else {
feature[i][j] = Double.parseDouble(tmp[j]);
}
}
}
return feature;
}
//导入样本标签
public static double[] LoadLabel(String filename) throws IOException{
File f = new File(filename);
FileInputStream fip = new FileInputStream(f);
// 构建FileInputStream对象
InputStreamReader reader = new InputStreamReader(fip,"UTF-8");
// 构建InputStreamReader对象,编码与写入相同
StringBuffer sb = new StringBuffer();
while(reader.ready()) {
sb.append((char) reader.read());
}
reader.close();
fip.close();
//将读入的数据流转换为字符串
String sb1 = sb.toString();
//按行将字符串分割,计算二维数组行数
String [] a = sb1.split("\n");
int n = a.length;
System.out.println("二维数组行数为:" + n);
//计算二维数组列数
String [] a0 = a[0].split("\t");
int m = a0.length;
System.out.println("二维数组列数为:" + m);
double [] Label = new double[n];
for (int i = 0; i < n; i ++) {
String [] tmp = a[i].split("\t");
Label[i] = Double.parseDouble(tmp[m-1]);
}
return Label;
}
}
然后,利用梯度下降算法优化Logistic Regression模型。
public class LRtrainGradientDescent {
int paraNum; //权重参数的个数
double rate; //学习率
int samNum; //样本个数
double [][] feature; //样本特征矩阵
double [] Label;//样本标签
int maxCycle; //最大迭代次数
public LRtrainGradientDescent(double [][] feature, double [] Label, int paraNum,double rate, int samNum,int maxCycle) {
this.feature = feature;
this.Label = Label;
this.maxCycle = maxCycle;
this.paraNum = paraNum;
this.rate = rate;
this.samNum = samNum;
}
// 权值矩阵初始化
public double [] ParaInitialize(int paraNum) {
double [] W = new double[paraNum];
for (int i = 0; i < paraNum; i ++) {
W[i] = 1.0;
}
return W;
}
//计算每次迭代后的预测误差
public double [] PreVal(int samNum,int paraNum, double [][] feature,double [] W) {
double [] Preval = new double[samNum];
for (int i = 0; i< samNum; i ++) {
double tmp = 0;
for(int j = 0; j < paraNum; j ++) {
tmp += feature[i][j] * W[j];
}
Preval[i] = Sigmoid.sigmoid(tmp);
}
return Preval;
}
//计算误差率
public double error_rate(int samNum, double [] Label, double [] Preval) {
double sum_err = 0.0;
for(int i = 0; i < samNum; i ++) {
sum_err += Math.pow(Label[i] - Preval[i], 2);
}
return sum_err;
}
//LR模型训练
public double[] Updata(double [][] feature, double[] Label, int maxCycle, double rate) {
// 先计算样本个数和特征个数
int samNum = feature.length;
int paraNum = feature[0].length;
//初始化权重矩阵
double [] W = ParaInitialize(paraNum);
// 循环迭代优化权重矩阵
for (int i = 0; i < maxCycle; i ++) {
// 每次迭代后,样本预测值
double [] Preval = PreVal(samNum,paraNum,feature,W);
double sum_err = error_rate(samNum,Label,Preval);
if (i % 10 == 0) {
System.out.println("第" + i + "次迭代的预测误差为:" + sum_err);
}
//预测值与标签的误差
double [] err = new double[samNum];
for(int j = 0; j < samNum; j ++) {
err[j] = Label[j] - Preval[j];
}
// 计算权重矩阵的梯度方向
double [] Delt_W = new double[paraNum];
for (int n = 0 ; n < paraNum; n ++) {
double tmp = 0;
for(int m = 0; m < samNum; m ++) {
tmp += feature[m][n] * err[m];
}
Delt_W[n] = tmp / samNum;
}
for(int m = 0; m < paraNum; m ++) {
W[m] = W[m] + rate * Delt_W[m];
}
}
return W;
}
}
Sigmoid函数
public class Sigmoid {
public static double sigmoid(double x) {
double i = 1.0;
double y = i / (i + Math.exp(-x));
return y;
}
}
Logistic Regression模型参数和测试结果存储。
import java.io.*;
public class SaveModel {
public static void savemodel(String filename, double [] W) throws IOException{
File f = new File(filename);
// 构建FileOutputStream对象
FileOutputStream fip = new FileOutputStream(f);
// 构建OutputStreamWriter对象
OutputStreamWriter writer = new OutputStreamWriter(fip,"UTF-8");
//计算模型矩阵的元素个数
int n = W.length;
StringBuffer sb = new StringBuffer();
for (int i = 0; i < n-1; i ++) {
sb.append(String.valueOf(W[i]));
sb.append("\t");
}
sb.append(String.valueOf(W[n-1]));
String sb1 = sb.toString();
writer.write(sb1);
writer.close();
fip.close();
}
public static void saveresults(String filename, double [] pre_results) throws IOException{
File f = new File(filename);
// 构建FileOutputStream对象
FileOutputStream fip = new FileOutputStream(f);
// 构建OutputStreamWriter对象
OutputStreamWriter writer = new OutputStreamWriter(fip,"UTF-8");
//计算预测结果的个数
int n = pre_results.length;
StringBuffer sb = new StringBuffer();
for (int i = 0; i < n-1; i ++) {
sb.append(String.valueOf(pre_results[i]));
sb.append("\n");
}
sb.append(String.valueOf(pre_results[n-1]));
String sb1 = sb.toString();
writer.write(sb1);
writer.close();
fip.close();
}
}
主类。
import java.io.*;
public class LRMain {
public static void main(String[] args) throws IOException{
// filename
String filename = "data.txt";
// 导入样本特征和标签
double [][] feature = LoadData.Loadfeature(filename);
double [] Label = LoadData.LoadLabel(filename);
// 参数设置
int samNum = feature.length;
int paraNum = feature[0].length;
double rate = 0.01;
int maxCycle = 1000;
// LR模型训练
LRtrainGradientDescent LR = new LRtrainGradientDescent(feature,Label,paraNum,rate,samNum,maxCycle);
double [] W = LR.Updata(feature, Label, maxCycle, rate);
//保存模型
String model_path = "wrights.txt";
SaveModel.savemodel(model_path, W);
//模型测试
double [] pre_results = LRTest.lrtest(paraNum, samNum, feature, W);
//保存测试结果
String results_path = "pre_results.txt";
SaveModel.saveresults(results_path, pre_results);
}
}