Java实现KNN(K-Nearest Neighbors,K近邻)

课程作业。

如果我没理解错的话,KNN就是

1. 计算待分类点与每个已知点的欧氏距离

2. 根据距离排序

3. 在距离最近的 k 个已知点(即"范围")内,统计每个类的已知点的个数

4. 出现已知点最多的那个类就视为待分类点的类



代码里用的已知点是用excel随机生成的,分为了三个类,每个类三十个点



只实现了功能,没有优化代码,不过如果考虑性能的话也没必要用Java来写了。。

import java.util.*;
import java.io.*;

/**
 * Demo entry point: classifies one hard-coded 3-D point with {@code KnnClassify}
 * and prints the resulting class label.
 */
public class KnnUse{
	public static void main(String[] args)
	{
		// Path handed to the classifier's constructor.
		final String csvLocation = "/home/stpraha/Desktop/KNNdata.csv";
		final int pointDimension = 3;	// number of coordinates per point
		final int numberOfClasses = 3;	// how many class labels exist
		final int k = 20;		// how many nearest neighbours take part in the vote
		final double[] query = {0.615354709,0.20145842,0.61402677};

		KnnClassify classifier = new KnnClassify(csvLocation, pointDimension, numberOfClasses);
		int label = classifier.Classify(k, query);
		System.out.println("The classify result is: " + label);	
	}
}

/**
 * A minimal k-nearest-neighbours classifier over a built-in training set.
 *
 * <p>The training points live in {@link #axisData} and their labels in
 * {@link #classData} (same index = same point). Labels are expected to be in the
 * range {@code 1..classNum}; {@code 0} is used by {@link #Classify} to signal
 * "no class could be chosen".
 */
class KnnClassify
{
	// Kept for API compatibility; the training data is currently embedded below
	// and this path is never read.
	private String filePath;
	// Number of leading coordinates of each point that take part in the distance.
	private int dimension;
	// Highest class label that can be returned.
	private int classNum;
	// Training points, one row per point.
	double[][] axisData =  {{6.1321582144,4.685213213,8.1417121221},{7.12487893,6.123115455,5.11545661},{9.44679060165588,5.321657536,7.15509904678304},{9.79526231893458,6.640888419,7.34285510993896},
	{4.77761465491205,5.674499364,8.10880050465455},{9.68833577559596,7.514095764,7.3273409886164},{7.20709307877799,7.178310023,4.68203079597704},{9.06883226786243,4.238236761,7.44221202087844},
	{6.22475171322394,5.687983428,4.38984288025036},{8.58441867916441,7.226113777,9.13233010742196},{5.99566194237124,9.866044126,5.59022410837582},
	{6.1388246252093,9.257754963,5.43900595313342},{7.15999435525307,5.727612593,7.97955578774912},{9.32252335774415,5.326093937,9.35169455478495},
	{4.01505470852214,6.201475842,4.61406887707296},{5.80884845393397,5.38273827,7.10742419817661},{5.89891954410616,4.223074801,5.1810495368954},
	{4.08693243273475,4.774333344,4.33110369241029},{8.58691973052222,5.303183952,7.09245386013323},{4.0298598197669,5.40102355,4.74019231176667},
	{9.80519248144921,4.5354452,6.69642029780013},{7.76782235412439,9.321334979,6.33799170370875},{6.91103394225716,9.351885349,5.6359844857326},
	{4.87450403761148,6.447533701,4.97511404730764},{5.76418325895154,4.142384284,8.88306791639143},{8.18214821568598,4.549198405,9.52789778064696},
	{4.84799433814581,5.867685017,5.73111593033916},{9.41296295347156,4.318493688,9.36945061351813},{6.22377771054667,8.57024766,6.06290640783125},{6.10814462821441,0.135781944,3.24342611324956},
	{6.1473012973997,3.831824726,2.03077675379352},{9.75681136517769,4.823726801,3.36548053533234},{7.47846326474974,4.867244449,1.7536709703457},{9.56262772253768,3.437613628,4.29771700057556},
	{5.2402185458792,4.322175435,5.92584949045604},{8.79261249214447,2.452853986,9.31133055017988},{6.66932175906236,5.909352726,5.21583827172789},{9.94207953888806,4.563875612,1.04616032804075},
	{5.16389920501087,1.888782789,7.13243187029199},{8.53362133928223,5.589058851,5.82012869834427},
	{7.72365348120155,5.961339832,7.15026778846361},{8.18785012740612,4.757999461,2.12655611573517},{7.20464612194617,4.731936762,4.08315172416153},{8.4639203787726,0.442752109,4.1107970206962},
	{8.69634978470303,3.55925944,7.80357698300718},{9.72196238186626,3.153092923,7.6191514657492},{6.94257485445718,3.344195601,4.93195995164083},{5.66650557490717,2.18736791,1.38760545980825},
	{5.51666238146926,2.005865827,5.7758673005361},{9.87738690426171,5.681270785,3.8183380732637},{6.86800732737168,0.61269626,9.49324343978673},{6.2593227054612,4.686138423,0.226595986760518},
	{5.13195880431701,0.006550889,3.9451114716201},{7.19514792784954,0.988708237,0.901109967085093},{5.88494987692637,2.831401722,9.73662636599028},{5.97289881935312,1.895373927,5.41008822979565},
	{6.31460350154249,5.312851902,3.13888393871109},{5.49241265814024,4.02364613,4.77710593860612},
	{7.57250344878882,5.436076761,4.71912352505689},{7.32164985376341,2.471078659,0.888507658001142},{0.580879295283057,5.317657905,5.66059144408722},{0.268748997891333,9.116979379,1.57506086565508},
	{3.10630876559569,3.95445715,0.0352144373642296},{4.87565543993233,3.817912989,1.21987577009568},
	{1.03797700224263,9.468454594,3.60537415993729},{0.947511655535578,7.994405955,1.63951127396781},{6.35923052681892,5.965198325,4.92681759059438},{3.19985429405129,5.445125311,2.55111711502164},
	{3.29053382829984,8.060325874,3.24515502488008},{1.93811996462737,6.494508534,1.40379540593345},{0.423589547398009,9.320525076,3.04641630918443},{5.69714153274838,2.718700555,0.578065557772626},
	{5.75763954822419,9.281467884,1.37125846027564},{0.850782173359781,5.974325854,3.04584961978115},{5.19174055707074,1.538150063,4.41537475395151},{3.03067516228081,5.294667858,4.45957357207296},
	{2.53664407123745,6.188508952,4.02526949788536},{0.774254827780066,6.160624484,3.44470500937769},{3.81220980323376,5.90896187,1.54850337715448},{0.994824823423188,4.575440034,0.291045026680756},
	{3.79193402266271,9.039067475,2.43845650679798},{4.88984496830391,5.274188419,0.307723319901437},{3.94567007652296,6.865964537,0.542955409306552},{2.56945333821223,0.368455462,2.0301898602399},
	{3.07244720131968,8.94782325,5.44289828457378},{2.2535664987324,8.384309326,1.57989075375354},{1.96223264826904,2.012881765,4.37686101760724},{2.75416307982501,9.547518142,3.27904060925972},
	{4.07816824353363,0.800816585,1.33368346849364},{3.05391567853604,2.973268049,2.67446765210255}
	};
	// Class label of the training point at the same index in axisData.
	int[] classData = {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3};
	
	/** Pairs a training point's (squared) distance to the query with that point's class label. */
	class distanceAndClass{
		double distanceWithPoint;
		int classOfPoint;
		distanceAndClass(double distanceWithPoint, int classOfPoint)
		{
			this.distanceWithPoint = distanceWithPoint;
			this.classOfPoint = classOfPoint;
		}
	}
	
	/** Orders {@link distanceAndClass} entries by ascending distance. */
	class DACComparator implements Comparator<distanceAndClass>{
		@Override
		public int compare(distanceAndClass c1, distanceAndClass c2) {
			// Double.compare gives a consistent total order, including for NaN,
			// which a hand-written subtraction does not.
			return Double.compare(c1.distanceWithPoint, c2.distanceWithPoint);
		}		
	}
	
	/**
	 * @param filePath  path of a CSV data file; stored but currently not read
	 * @param dimension number of coordinates compared per point
	 * @param classNum  highest class label present in the training data
	 */
	KnnClassify(String filePath, int dimension, int classNum)
	{
		this.filePath = filePath;
		this.dimension = dimension;
		this.classNum = classNum;
	}

	/**
	 * Classifies a point by majority vote among its nearest training points.
	 *
	 * <p>Each call is independent: no distances are carried over from a previous call.
	 *
	 * @param range           k, the number of nearest neighbours that vote; values larger
	 *                        than the training set are clamped to its size
	 * @param pointToClassify coordinates of the query point; must hold at least
	 *                        {@code dimension} values
	 * @return the label with the most votes (the lowest label wins a tie), or {@code 0}
	 *         when no neighbour voted (i.e. {@code range <= 0})
	 */
	public int Classify(int range, double[] pointToClassify)
	{
		// One slot per label; labels start at 1, so slot 0 normally stays empty.
		int[] classCount = new int[classNum+1];
		
		// Local list so repeated calls never see distances from an earlier query.
		List<distanceAndClass> neighbours = new ArrayList<distanceAndClass>(axisData.length);
		for(int i = 0; i < axisData.length; i++)
		{
			// Squared Euclidean distance: the square root is skipped because it
			// does not change the ordering of the neighbours.
			double distance = 0;
			for(int j = 0; j < dimension; j++)
			{
				double diff = pointToClassify[j] - axisData[i][j];
				distance += diff * diff;
			}
			neighbours.add(new distanceAndClass(distance, classData[i]));
		}
		
		// Nearest points first.
		Collections.sort(neighbours, new DACComparator());
		
		// Let the k nearest points vote; never read past the end of the data.
		int k = Math.min(range, neighbours.size());
		for(int i = 0; i < k; i++)
		{
			classCount[neighbours.get(i).classOfPoint]++;
		}
		
		// Pick the label with the most votes. The bound is inclusive so that the
		// highest label (classNum) is considered too.
		int result = 0;
		int max = 0;
		for(int i = 0; i <= classNum; i++)
		{
			if(classCount[i] > max)
			{
				result = i;
				max = classCount[i];
			}	
		}
		
		return result;
	}
	
	/**
	 * Placeholder for loading training data from a CSV file.
	 *
	 * @param filePath path of the CSV file (ignored)
	 * @return always {@code null}; reading is not implemented because the data is embedded
	 */
	private double[][] readCsvToIntArray(String filePath)
	{
		//read file is not need here
		return null;
	}
}



猜你喜欢

转载自blog.csdn.net/sinat_15901371/article/details/80711905