直线拟合最小二乘法

数据

x=(1,2,3,4,5)

y = (1,1.5,3,4.5,5)

算法结果

(原文此处为算法运行结果截图,图片缺失。按下文公式可算得 k = 1.1、b = −0.3,即拟合直线 y = 1.1x − 0.3。)

R语言运行结果

(原文此处为 R 语言运行结果截图,图片缺失。R 的 lm(y ~ x) 得到相同的系数:斜率 1.1,截距 −0.3。)

算法原理

x的均值:

$x_p = \frac{1}{n}\sum_{i=1}^{n} x_i = \frac{x_1 + x_2 + \cdots + x_n}{n}$

y的均值 :

$y_p = \frac{1}{n}\sum_{i=1}^{n} y_i = \frac{y_1 + y_2 + \cdots + y_n}{n}$

x 的离差平方和:

$l_{xx} = \sum_{i=1}^{n} (x_i - x_p)^2$

x 与 y 的离差乘积之和(与协方差成正比):

$l_{xy} = \sum_{i=1}^{n} (x_i - x_p)(y_i - y_p)$

拟合直线 $\hat{y} = kx + b$,其中:

$k = \dfrac{l_{xy}}{l_{xx}}$

$b = y_p - k\,x_p$

代码实现

(数据容器选用 List 集合,这样可以把数据当作向量来运算。)

集合求和方法

	/**
	 * Sums every element of {@code c} as a {@code double}, in iteration order.
	 *
	 * @param c the numbers to add; may be empty
	 * @return the sum ({@code 0.0} for an empty list), or {@code Double.NaN} when {@code c}
	 *     is {@code null}
	 * @throws NullPointerException if {@code c} contains a {@code null} element
	 */
	public static double sum(List<? extends Number> c) {
		// A null list is reported as NaN instead of throwing; check it directly rather than
		// provoking and catching an exception.
		if (c == null) {
			return Double.NaN;
		}
		double ret = 0;
		for (Number n : c) {
			ret += n.doubleValue();
		}
		return ret;
	}

计算协方差时需要把一组数连乘(如 (xi-xp)*(yi-yp)),这和求和一样,都是对集合中的元素逐个做累积运算,因此不妨定义一个通用的连续计算方法。

集合连续计算方法

第一个参数为一个集合
第二个参数为函数接口:入参 1 为当前累积值,入参 2 为下一个待计算的元素,返回值为本次计算结果,并作为下一次迭代的入参 1。

	/**
	 * Folds the elements of {@code c} from left to right with {@code fun}
	 * ("continuous computation").
	 *
	 * <p>The first element seeds the accumulator. Every later element is passed to
	 * {@code fun} as a {@code double}, together with the current accumulator, and the result
	 * becomes the next accumulator. For example,
	 * {@code conc(x, (r, n) -> r.doubleValue() + n.doubleValue())} sums {@code x} and
	 * {@code conc(x, (r, n) -> r.doubleValue() * n.doubleValue())} multiplies it.
	 *
	 * @param c   the numbers to fold; elements are expected to be non-null
	 * @param fun maps {@code (accumulator, next element)} to the new accumulator; only
	 *            invoked when {@code c} has at least two elements
	 * @return the final accumulator as a {@code double} (the single element for a
	 *         one-element list), or {@code Double.NaN} when {@code c} is {@code null} or empty
	 */
	public static double conc(List<? extends Number> c, BiFunction<Number, Number, Number> fun) {
		// An empty list has no seed value; report NaN instead of dereferencing a null
		// accumulator.
		if (c == null || c.isEmpty()) {
			return Double.NaN;
		}
		Iterator<? extends Number> it = c.iterator();
		Number acc = it.next();
		while (it.hasNext()) {
			acc = fun.apply(acc, it.next().doubleValue());
		}
		return acc.doubleValue();
	}

使用方法(具体运算由传入的函数接口决定):
可以求和:
conc(x,(r,n)->r.doubleValue()+n.doubleValue());
可以求积:
conc(x,(r,n)->r.doubleValue()*n.doubleValue());

对集合内每一个元素分别计算,结果放入新集合返回(原集合不变)

计算 xi-xp

	/**
	 * Applies a function to every element of a list and collects the results, in order, into
	 * a new list; the input list is not modified. Used e.g. to compute {@code xi - xp}.
	 *
	 * <p>Returns {@code null} when either the list or the function is {@code null}.
	 */
	public static final BiFunction<List<Number>, Function<Number, Number>, List<Number>>
	calc = (a, fun) -> {
		if (a == null || fun == null) {
			return null;
		}
		List<Number> b = new ArrayList<>(a.size());
		for (Number item : a) {
			b.add(fun.apply(item));
		}
		return b;
	};

多个集合运算生成新集合

主要用于计算协方差:把两个集合按下标逐元素相乘,即 (xi-xp)*(yi-yp)

	/**
	 * Combines several equally long lists element by element and maps each resulting row to a
	 * single number. Row {@code h} is {@code [c.get(0).get(h), c.get(1).get(h), ...]}; used
	 * e.g. to compute {@code (xi - xp) * (yi - yp)} for every {@code i}.
	 *
	 * <p>Returns {@code null} when either argument is {@code null}, and an empty list when
	 * {@code c} is empty. Throws {@code IllegalArgumentException} if the inner lists differ
	 * in length.
	 */
	public static final BiFunction<List<List<Number>>, Function<List<Number>, Number>, List<Number>>
	call = (c, fun) -> {
		if (c == null || fun == null) {
			return null;
		}
		List<Number> b = new ArrayList<>();
		if (c.isEmpty()) {
			return b;
		}
		int height = c.get(0).size();
		// Mismatched lengths would otherwise either fail mid-way with an index error or
		// silently ignore the extra elements of the longer lists.
		for (List<Number> column : c) {
			if (column.size() != height) {
				throw new IllegalArgumentException(
						"all lists must have size " + height + " but found " + column.size());
			}
		}
		for (int h = 0; h < height; h++) {
			List<Number> row = new ArrayList<>(c.size());
			for (List<Number> column : c) {
				row.add(column.get(h));
			}
			b.add(fun.apply(row));
		}
		return b;
	};

拟合算法

	/**
	 * Fits the straight line {@code y = k*x + b} to the points {@code (x[i], y[i])} by
	 * ordinary least squares.
	 *
	 * <p>With {@code xp}/{@code yp} the means of x/y, {@code lxx = sum((xi-xp)^2)} and
	 * {@code lxy = sum((xi-xp)*(yi-yp))}, the slope is {@code k = lxy / lxx} and the
	 * intercept is {@code b = yp - k*xp}.
	 *
	 * @param x the x coordinates
	 * @param y the y coordinates, same length as {@code x}
	 * @return {@code {k, b}}; the result is not meaningful (typically NaN) when all x values
	 *     are equal, because then {@code lxx} is zero
	 * @throws NullPointerException     if {@code x} or {@code y} is {@code null}
	 * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
	 */
	public static double[] lineFit(List<Number> x, List<Number> y) {
		Objects.requireNonNull(x, "x");
		Objects.requireNonNull(y, "y");
		if (x.size() != y.size()) {
			throw new IllegalArgumentException(
					"x and y must have the same size: " + x.size() + " != " + y.size());
		}
		double xp = mean(x);
		List<Number> xi_xp = calc.apply(x, e -> e.doubleValue() - xp);
		double lxx = sum(calc.apply(xi_xp, e -> e.doubleValue() * e.doubleValue()));
		double yp = mean(y);
		List<Number> yi_yp = calc.apply(y, e -> e.doubleValue() - yp);
		// Pair (xi - xp) with (yi - yp) and multiply each pair.
		double lxy = sum(call.apply(Arrays.asList(xi_xp, yi_yp),
				e -> conc(e, (r, n) -> r.doubleValue() * n.doubleValue())));
		double k = lxy / lxx;
		double b = yp - k * xp;
		return new double[] {k, b};
	}

完整代码

package utility;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;

import javafx.application.Application;
import javafx.scene.Parent;
import javafx.scene.Scene;
import javafx.scene.chart.LineChart;
import javafx.scene.chart.NumberAxis;
import javafx.scene.chart.XYChart.Data;
import javafx.scene.chart.XYChart.Series;
import javafx.stage.Stage;
/**
 * Least-squares straight-line fit {@code y = k*x + b}, plus a small JavaFX window that plots
 * the sample points together with the fitted line.
 *
 * <p>The numeric helpers treat {@code List<Number>} values as vectors: {@link #sum},
 * {@link #mean} and {@link #conc} reduce a list to one number, {@link #calc} maps a list
 * element-wise, and {@link #call} combines several equally long lists row by row.
 */
public class LineFit extends Application {

	/**
	 * Sums every element of {@code c} as a {@code double}, in iteration order.
	 *
	 * @param c the numbers to add; may be empty
	 * @return the sum ({@code 0.0} for an empty list), or {@code Double.NaN} when {@code c}
	 *     is {@code null}
	 * @throws NullPointerException if {@code c} contains a {@code null} element
	 */
	public static double sum(List<? extends Number> c) {
		if (c == null) {
			return Double.NaN;
		}
		double ret = 0;
		for (Number n : c) {
			ret += n.doubleValue();
		}
		return ret;
	}

	/**
	 * Folds the elements of {@code c} from left to right with {@code fun}
	 * ("continuous computation").
	 *
	 * <p>The first element seeds the accumulator. Every later element is passed to
	 * {@code fun} as a {@code double}, together with the current accumulator, and the result
	 * becomes the next accumulator.
	 *
	 * @param c   the numbers to fold; elements are expected to be non-null
	 * @param fun maps {@code (accumulator, next element)} to the new accumulator; only
	 *            invoked when {@code c} has at least two elements
	 * @return the final accumulator as a {@code double}, or {@code Double.NaN} when
	 *         {@code c} is {@code null} or empty
	 */
	public static double conc(List<? extends Number> c, BiFunction<Number, Number, Number> fun) {
		// An empty list has no seed value; report NaN instead of dereferencing a null
		// accumulator.
		if (c == null || c.isEmpty()) {
			return Double.NaN;
		}
		Iterator<? extends Number> it = c.iterator();
		Number acc = it.next();
		while (it.hasNext()) {
			acc = fun.apply(acc, it.next().doubleValue());
		}
		return acc.doubleValue();
	}

	/**
	 * Arithmetic mean of {@code c}.
	 *
	 * @param c the numbers to average
	 * @return the mean, or {@code Double.NaN} when {@code c} is {@code null} or empty
	 */
	public static double mean(List<? extends Number> c) {
		// The mean of an empty sample is undefined; NaN (rather than a fabricated 0) keeps
		// that visible to callers.
		if (c == null || c.isEmpty()) {
			return Double.NaN;
		}
		return sum(c) / c.size();
	}

	/**
	 * Fits the straight line {@code y = k*x + b} to the points {@code (x[i], y[i])} by
	 * ordinary least squares: {@code k = lxy / lxx}, {@code b = yp - k*xp}.
	 *
	 * @param x the x coordinates
	 * @param y the y coordinates, same length as {@code x}
	 * @return {@code {k, b}}; not meaningful (typically NaN) when all x values are equal
	 * @throws NullPointerException     if {@code x} or {@code y} is {@code null}
	 * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
	 */
	public static double[] lineFit(List<Number> x, List<Number> y) {
		Objects.requireNonNull(x, "x");
		Objects.requireNonNull(y, "y");
		if (x.size() != y.size()) {
			throw new IllegalArgumentException(
					"x and y must have the same size: " + x.size() + " != " + y.size());
		}
		double xp = mean(x);
		List<Number> xi_xp = calc.apply(x, e -> e.doubleValue() - xp);
		double lxx = sum(calc.apply(xi_xp, e -> e.doubleValue() * e.doubleValue()));
		double yp = mean(y);
		List<Number> yi_yp = calc.apply(y, e -> e.doubleValue() - yp);
		// Pair (xi - xp) with (yi - yp) and multiply each pair.
		double lxy = sum(call.apply(Arrays.asList(xi_xp, yi_yp),
				e -> conc(e, (r, n) -> r.doubleValue() * n.doubleValue())));
		double k = lxy / lxx;
		double b = yp - k * xp;
		return new double[] {k, b};
	}

	public static void main(String[] args) {
		launch();
	}

	/** Shows a chart of the sample data and its fitted line. */
	@Override
	public void start(Stage primaryStage) throws Exception {
		List<Number> x = Arrays.asList(1, 2, 3, 4, 5);
		List<Number> y = Arrays.asList(1, 1.5, 3, 4.5, 5);
		primaryStage.setScene(new Scene(plot(x, y)));
		primaryStage.show();
	}

	/**
	 * Builds the chart series of the fitted line, evaluated at every x value.
	 *
	 * @return a series named after the fitted equation
	 */
	public Series<Number, Number> test(List<Number> x, List<Number> y) {
		double[] kb = lineFit(x, y);
		double k = kb[0];
		double b = kb[1];
		Series<Number, Number> series = new Series<>();
		for (Number xi : x) {
			series.getData().add(new Data<Number, Number>(xi, k * xi.doubleValue() + b));
		}
		// Format both coefficients the same way; k used to be printed with full precision.
		series.setName("拟合直线 " + "Y=" + String.format("%.2f", k)
				+ "x+(" + String.format("%.2f", b) + ")");
		return series;
	}

	/** Builds the chart series of the raw (x, y) points, paired up to the shorter list. */
	public Series<Number, Number> data(List<Number> x, List<Number> y) {
		Series<Number, Number> series = new Series<>();
		Iterator<Number> xi = x.iterator();
		Iterator<Number> yi = y.iterator();
		while (xi.hasNext() && yi.hasNext()) {
			series.getData().add(new Data<Number, Number>(xi.next(), yi.next()));
		}
		series.setName("data");
		return series;
	}

	/** Creates a line chart showing both the data points and the fitted line. */
	public Parent plot(List<Number> x, List<Number> y) {
		NumberAxis xAxis = new NumberAxis();
		NumberAxis yAxis = new NumberAxis();
		LineChart<Number, Number> chart = new LineChart<>(xAxis, yAxis);
		chart.getData().add(data(x, y));
		chart.getData().add(test(x, y));
		return chart;
	}

	/**
	 * Applies a function to every element of a list and collects the results, in order, into
	 * a new list; the input list is not modified. Returns {@code null} when either argument
	 * is {@code null}.
	 */
	public static final BiFunction<List<Number>, Function<Number, Number>, List<Number>>
	calc = (a, fun) -> {
		if (a == null || fun == null) {
			return null;
		}
		List<Number> b = new ArrayList<>(a.size());
		for (Number item : a) {
			b.add(fun.apply(item));
		}
		return b;
	};

	/**
	 * Combines several equally long lists element by element and maps each row
	 * {@code [c.get(0).get(h), c.get(1).get(h), ...]} to a single number. Returns
	 * {@code null} when either argument is {@code null} and an empty list when {@code c} is
	 * empty; throws {@code IllegalArgumentException} if the inner lists differ in length.
	 */
	public static final BiFunction<List<List<Number>>, Function<List<Number>, Number>, List<Number>>
	call = (c, fun) -> {
		if (c == null || fun == null) {
			return null;
		}
		List<Number> b = new ArrayList<>();
		if (c.isEmpty()) {
			return b;
		}
		int height = c.get(0).size();
		for (List<Number> column : c) {
			if (column.size() != height) {
				throw new IllegalArgumentException(
						"all lists must have size " + height + " but found " + column.size());
			}
		}
		for (int h = 0; h < height; h++) {
			List<Number> row = new ArrayList<>(c.size());
			for (List<Number> column : c) {
				row.add(column.get(h));
			}
			b.add(fun.apply(row));
		}
		return b;
	};
}

R语言代码

# Reference fit with R's built-in linear model, for comparison with the Java result.
x <- c(1, 2, 3, 4, 5)
y <- c(1, 1.5, 3, 4.5, 5)
data1 <- data.frame(x = x, y = y)
lm.data1 <- lm(y ~ x, data = data1)
# coef() returns c(intercept, slope); unname() keeps plain numbers for the label.
b <- unname(round(coef(lm.data1)[1], 3))
k <- unname(round(coef(lm.data1)[2], 3))
plot(data1$x, data1$y, xlab = "x", ylab = "y", col = "red", pch = "*")
abline(lm.data1, col = "blue")
# Label the plot with the fitted equation: slope k, then intercept b.
text(mean(data1$x), max(data1$y), paste("y = ", k, "x+(", b, ")", sep = ""))

猜你喜欢

转载自blog.csdn.net/qq_39464369/article/details/90038571