多层感知机结合反向传播调节(MLP-BP)的二分类实例java实现

使用java语言编写感知网络实现简单的二分类。算法目标为:在以原点为中心的平面区域内,如果点落在半径为1的圆内就是类别A;如果落在圆之外、边长为4的正方形里就是类别B;如果测试结果出现其他值可以标记为X。图片效果展示如图

感知机模型为:输入点数据(x,y)经过隐层10个隐含节点的计算最后输出值判定类型为A还是B。连接线代表不同的权值(虽然线颜色一样,偏置没有画出)

下面是具体实现:main函数:

package mlp;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;

public class mlpmain {
	public static int iter = 1000;       // number of training iterations
	public static int traindata = 1000;  // training set size
	public static int testdata = 100;    // test set size

	public static double[] stepSize = {0.1, 0.2, 0.3}; // learning rates to compare
	public static double threshold = 0.1;   // per-output error tolerance used during training
	public static double weightRange = 0.001; // scale of the random initial weights
	public static double momentum = 0.5;    // momentum factor for the weight updates

	/**
	 * Trains and tests one MLP per learning rate, prints the training/test
	 * success rates, and writes every test point together with its predicted
	 * type to a per-learning-rate CSV file.
	 *
	 * @throws IOException if the result file cannot be written
	 */
	public static void main(String args[]) throws IOException {

		for (int i = 0; i < stepSize.length; i++) {
			System.out.println("迭代次数: " + iter + " 训练集: " + traindata + " 测试集: " + testdata + " 学习率: " + stepSize[i] +
					" 判定阈值: " + threshold + " 初始化阈值参数" + weightRange + " 动量调节: " + momentum);
			network mymlp = new network(iter, traindata, testdata, stepSize[i], threshold, weightRange, momentum);

			mymlp.initNetwork(); // build data sets, nodes and random weights
			mymlp.train();       // train
			mymlp.test();        // test
			System.out.println("完成全部训练时成功率 " + mymlp.getSuccrate() + "%");
			System.out.println("测试成功率 " + mymlp.gettestsuccrate() + "%");
			System.out.println();

			// Compare each test point against its predicted type.
			// try-with-resources guarantees the writer is closed even if a write fails
			// (the original leaked the stream on exception).
			try (BufferedWriter out = new BufferedWriter(new FileWriter(stepSize[i] + "Resultscomparing.csv"))) {
				for (int j = 0; j < mymlp.getTestPoint().size(); j++) {
					out.write(mymlp.getTestPoint().elementAt(j) + "\n");
					// label before the value (the original printed the label after it,
					// leaving "testresult type:" dangling with nothing following)
					out.write("testresult type: " + mymlp.gettype()[j] + "\n");
					out.write("\n");
				}
			}
		}
	}
}

网络初始化,前向和反向计算

package mlp;

import java.util.Vector;

/**
 * A 2-10-2 multi-layer perceptron trained with plain back-propagation plus a
 * momentum term. Classifies 2D points as 'A' (inside the unit circle) or 'B'
 * (outside it), using a one-hot output encoding: (1,0)=A, (0,1)=B.
 */
public class network {
	private int iteration;       // remaining training iterations
	private int train;           // number of training points
	private int test;            // number of test points
	private double stepsize;     // learning rate
	private double threshold;    // per-output error tolerance for a "successful" training sample
	private double weighRange;   // scale of the random initial weights
	private double momentum;     // momentum factor
	private int inputsize = 2;   // (x, y) coordinates
	private int hinddensize = 10;// hidden-layer width
	private int outputsize = 2;  // one-hot outputs
	private node[] inputnode;
	private node[] hiddennode;
	private node[] outputnode;
	// Deltas and weights; the old* arrays hold the previous weights for the momentum term.
	private double[] hinddenDelta;
	private double[] outputDelta;
	private double[][] inputweight;
	private double[][] oldInputeight;
	private double[][] outputweight;
	private double[][] oldoutputweight;
	// Randomly generated data sets.
	private Vector<datapoint> trainPoint;
	private Vector<datapoint> testPoint;
	private int success;
	private double succrate;     // training success rate of the last iteration
	private double testsuccrate; // test success rate
	private double suc[];        // success rate of every training iteration
	private char[] type;         // predicted class of each test point

	/**
	 * @param iteration  number of passes over the training set
	 * @param train      training set size
	 * @param test       test set size
	 * @param stepsize   learning rate
	 * @param threshold  per-output error tolerance during training
	 * @param weighRange scale applied to the random initial weights
	 * @param momentum   momentum factor for all weight/bias updates
	 */
	public network(int iteration, int train, int test, double stepsize, double threshold, double weighRange, double momentum) {
		this.iteration = iteration;
		this.train = train;
		this.test = test;
		this.stepsize = stepsize;
		this.threshold = threshold;
		this.weighRange = weighRange;
		this.momentum = momentum;
		this.succrate = 0;
		this.suc = new double[iteration];
		this.type = new char[test];
		inputnode = new node[inputsize];
		hiddennode = new node[hinddensize];
		outputnode = new node[outputsize];
		hinddenDelta = new double[hinddensize];
		outputDelta = new double[outputsize];
		inputweight = new double[inputsize][hinddensize];
		oldInputeight = new double[inputsize][hinddensize];
		outputweight = new double[hinddensize][outputsize];
		oldoutputweight = new double[hinddensize][outputsize];
		trainPoint = new Vector<datapoint>();
		testPoint = new Vector<datapoint>();
	}

	public double[] getSuc() {
		return suc;
	}

	public char[] gettype() {
		return type;
	}

	public Vector<datapoint> getTrainPoint() {
		return trainPoint;
	}

	public Vector<datapoint> getTestPoint() {
		return testPoint;
	}

	public int getsuccess() {
		return success;
	}

	public double getSuccrate() {
		return succrate;
	}

	public double gettestsuccrate() {
		return testsuccrate;
	}

	/** Builds the data sets, the nodes and the random initial weights. */
	public void initNetwork() {
		initPoints();
		initNodes();
		initWeights(weighRange);
	}

	/** Initializes every weight uniformly in [-range, range). */
	private void initWeights(double range) {
		// Use the parameter (the original ignored it and read the field directly).
		for (int i = 0; i < inputsize; i++) {
			for (int j = 0; j < hinddensize; j++) {
				inputweight[i][j] = randomBais() * range;
				oldInputeight[i][j] = inputweight[i][j]; // no momentum on the first update
			}
		}
		for (int i = 0; i < hinddensize; i++) {
			for (int j = 0; j < outputsize; j++) {
				outputweight[i][j] = randomBais() * range;
				oldoutputweight[i][j] = outputweight[i][j];
			}
		}
	}

	/**
	 * Uniform random value in [-1, 1). The original returned Math.random()-1.0,
	 * i.e. only [-1, 0), so every initial weight and bias was negative.
	 */
	private double randomBais() {
		return 2.0 * Math.random() - 1.0;
	}

	/** Creates all nodes with random initial activations and biases. */
	private void initNodes() {
		for (int i = 0; i < inputsize; i++) {
			inputnode[i] = new node(Math.random(), randomBais());
		}
		for (int i = 0; i < hinddensize; i++) {
			hiddennode[i] = new node(Math.random(), randomBais());
		}
		for (int i = 0; i < outputsize; i++) {
			outputnode[i] = new node(Math.random(), randomBais());
		}
	}

	/** Generates the random training and test points. */
	private void initPoints() {
		for (int i = 0; i < train; i++) {
			trainPoint.addElement(new datapoint());
		}
		for (int i = 0; i < test; i++) {
			testPoint.add(new datapoint());
		}
	}

	/**
	 * Runs the configured number of training iterations, recording the success
	 * rate of each iteration in {@code suc}.
	 */
	public void train() {
		while (iteration > 0) {
			for (int i = 0; i < train; i++) {
				double[] desiredOutput = new double[outputsize];
				inputnode[0].setActivation(trainPoint.elementAt(i).getX());
				inputnode[1].setActivation(trainPoint.elementAt(i).getY());
				// One-hot target: (1,0) for class A, (0,1) for class B.
				if (trainPoint.elementAt(i).getType() == 'A') {
					desiredOutput[0] = 1.0;
					desiredOutput[1] = 0.0;
				} else {
					desiredOutput[0] = 0.0;
					desiredOutput[1] = 1.0;
				}
				Forward();
				backPropagation(desiredOutput);
				// A sample counts as one success only when EVERY output is within
				// the tolerance. The original loop started at j=1 and therefore
				// never checked output node 0.
				boolean withinThreshold = true;
				for (int j = 0; j < outputsize; j++)
					if (Math.abs(desiredOutput[j] - outputnode[j].getActivation()) >= threshold)
						withinThreshold = false;
				if (withinThreshold)
					success++;
			}
			iteration--;
			succrate = (success * 100.0) / train;
			suc[suc.length - iteration - 1] = succrate;
			success = 0;
		}
	}

	/** Classifies every test point and records the overall test success rate. */
	public void test() {
		int suRate = 0;
		for (int i = 0; i < test; i++) {
			inputnode[0].setActivation(testPoint.elementAt(i).getX());
			inputnode[1].setActivation(testPoint.elementAt(i).getY());
			Forward();
			// Classify once and reuse the result (the original called classify twice).
			char predicted = classify(outputnode);
			type[i] = predicted;
			if (testPoint.elementAt(i).getType() == predicted) {
				testPoint.elementAt(i).setSuccess(true);
				suRate++;
			} else {
				testPoint.elementAt(i).setSuccess(false);
			}
		}
		testsuccrate = (suRate * 100.0) / test;
	}

	/**
	 * Maps the output activations to a class label: 'A' when output 0 dominates,
	 * 'B' when output 1 dominates, 'X' when the outputs are ambiguous.
	 */
	private char classify(node[] outputs) {
		// Use the parameter (the original ignored it and read the field).
		if (outputs[0].getActivation() > 0.5 && outputs[1].getActivation() < 0.5)
			return 'A';
		if (outputs[0].getActivation() < 0.5 && outputs[1].getActivation() > 0.5)
			return 'B';
		return 'X';
	}

	/**
	 * One back-propagation step: computes the output and hidden deltas for the
	 * current activations, then updates all weights and biases with a
	 * gradient + momentum rule.
	 */
	private void backPropagation(double[] desiredOutput) {
		double temp;
		// Output deltas: sigmoid derivative a*(1-a) times the output error.
		for (int i = 0; i < outputsize; i++)
			outputDelta[i] = (outputnode[i].getActivation() * (1 - outputnode[i].getActivation())) *
							 (desiredOutput[i] - outputnode[i].getActivation());
		// Hidden deltas: back-propagated weighted sum of the output deltas.
		for (int i = 0; i < hinddensize; i++) {
			temp = 0.0;
			for (int j = 0; j < outputsize; j++)
				temp += outputDelta[j] * outputweight[i][j];
			hinddenDelta[i] = hiddennode[i].getActivation() * (1.0 - hiddennode[i].getActivation()) * temp;
		}
		// Input->hidden weights: gradient step plus momentum (difference to the previous weight).
		for (int i = 0; i < inputsize; i++) {
			for (int j = 0; j < hinddensize; j++) {
				temp = inputweight[i][j] + (stepsize * hinddenDelta[j] * inputnode[i].getActivation()) +
					   (momentum * (inputweight[i][j] - oldInputeight[i][j]));
				oldInputeight[i][j] = inputweight[i][j];
				inputweight[i][j] = temp;
			}
		}
		// Hidden->output weights.
		for (int i = 0; i < hinddensize; i++) {
			for (int j = 0; j < outputsize; j++) {
				temp = outputweight[i][j] + (stepsize * outputDelta[j] * hiddennode[i].getActivation()) +
					   (momentum * (outputweight[i][j] - oldoutputweight[i][j]));
				oldoutputweight[i][j] = outputweight[i][j];
				outputweight[i][j] = temp;
			}
		}
		// Hidden biases.
		for (int i = 0; i < hinddensize; i++) {
			temp = hiddennode[i].getBias() + (stepsize * hinddenDelta[i]) +
				   (momentum * (hiddennode[i].getBias() - hiddennode[i].getOldbais()));
			hiddennode[i].setOldbais(hiddennode[i].getBias());
			hiddennode[i].setBias(temp);
		}
		// Output biases.
		for (int i = 0; i < outputsize; i++) {
			temp = outputnode[i].getBias() + (stepsize * outputDelta[i]) +
				   (momentum * (outputnode[i].getBias() - outputnode[i].getOldbais()));
			outputnode[i].setOldbais(outputnode[i].getBias());
			outputnode[i].setBias(temp);
		}
	}

	/** Forward pass: input -> hidden -> output, sigmoid activation throughout. */
	private void Forward() {
		double temp;
		for (int i = 0; i < hinddensize; i++) {
			temp = 0.0;
			for (int j = 0; j < inputsize; j++)
				temp += inputnode[j].getActivation() * inputweight[j][i];
			hiddennode[i].setActivation(sigmoid(temp + hiddennode[i].getBias()));
		}
		for (int i = 0; i < outputsize; i++) {
			temp = 0.0;
			for (int j = 0; j < hinddensize; j++)
				temp += hiddennode[j].getActivation() * outputweight[j][i];
			outputnode[i].setActivation(sigmoid(temp + outputnode[i].getBias()));
		}
	}

	/** Logistic sigmoid 1/(1+e^-d). */
	private double sigmoid(double d) {
		return 1 / (1 + Math.exp(-1 * d));
	}
}

节点构造函数:

package mlp;

/**
 * One neuron of the MLP: an activation value plus a bias, with the previous
 * bias kept around so back-propagation can apply a momentum term.
 */
public class node {
	private double activation; // current output value of the node
	private double bias;       // additive bias term
	private double oldbais;    // bias before the last update, used for momentum

	/** Creates a node with the given initial activation and bias. */
	public node(double a, double b) {
		activation = a;
		bias = b;
	}

	public double getActivation() {
		return activation;
	}

	public void setActivation(double activation) {
		this.activation = activation;
	}

	public double getBias() {
		return bias;
	}

	public void setBias(double bias) {
		this.bias = bias;
	}

	public double getOldbais() {
		return oldbais;
	}

	public void setOldbais(double oldbais) {
		this.oldbais = oldbais;
	}

	/** "activation bias", space-separated. */
	public String toString() {
		return activation + " " + bias;
	}
}

数据集构造函数:

package mlp;

/**
 * A 2D sample point. Class 'A' when it lies inside (or on) the unit circle
 * centred at the origin, 'B' otherwise. {@code success} records whether the
 * network classified this point correctly during testing.
 */
public class datapoint {
	private double x;
	private double y;
	private char type;
	private boolean success;

	/** Random point uniformly in the square [-2, 2) x [-2, 2), auto-labelled. */
	public datapoint() {
		this.x = (Math.random() * 4) - 2;
		this.y = (Math.random() * 4) - 2;
		this.type = labelFor(this.x, this.y);
		this.success = false;
	}

	/** Point at the given coordinates, auto-labelled. */
	public datapoint(double x, double y) {
		this.x = x;
		this.y = y;
		this.type = labelFor(x, y);
		// The original constructor left success uninitialized; make it explicit.
		this.success = false;
	}

	/** 'B' outside the unit circle, 'A' inside or on it. */
	private static char labelFor(double x, double y) {
		return (x * x + y * y) > 1 ? 'B' : 'A';
	}

	public double getX() {
		return x;
	}

	public double getY() {
		return y;
	}

	public char getType() {
		return type;
	}

	public void setType(char type) {
		this.type = type;
	}

	public void setSuccess(boolean success) {
		this.success = success;
	}

	/** "type, x, y, success|failed" — the original failure branch printed a doubled comma and the typo "faile". */
	@Override
	public String toString() {
		return this.type + ", " + this.x + ", " + this.y + (this.success ? ", success" : ", failed");
	}
}

测试结果:可以看到不同学习率的大小测试情况。也以文件形式保存了测试数据集的数据值和测试结果预测类型值

迭代次数: 1000 训练集: 1000 测试集: 100 学习率: 0.1 判定阈值: 0.1 初始化阈值参数0.001 动量调节: 0.5
完成全部训练时成功率 94.3%
测试成功率 99.0%

迭代次数: 1000 训练集: 1000 测试集: 100 学习率: 0.2 判定阈值: 0.1 初始化阈值参数0.001 动量调节: 0.5
完成全部训练时成功率 96.7%
测试成功率 100.0%

迭代次数: 1000 训练集: 1000 测试集: 100 学习率: 0.3 判定阈值: 0.1 初始化阈值参数0.001 动量调节: 0.5
完成全部训练时成功率 94.9%
测试成功率 99.0%

发布了105 篇原创文章 · 获赞 86 · 访问量 15万+

猜你喜欢

转载自blog.csdn.net/dingyahui123/article/details/81206156