K最近邻（KNN）算法

K最近邻算法：给定一组已知类别的训练数据，输入一个新的测试数据点时，找出距离该测试点最近的 K 个训练点并统计它们的类别，哪个类别占多数，该测试点就被归为哪一类。此外，有时也可以为不同的近邻点赋予不同的权重：距离近的点权重大一些，距离远的点权重自然小一些。

package knn;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;

import tool.DataDealing;
import tool.KNN_Node;
import tool.Matrix_2D;
import tool.ReadData;

/**
 * K-nearest-neighbour classifier. Each row of the loaded data set is a training example whose
 * last column is its class label; the remaining columns are features. Cell values are turned into
 * numbers by {@link DataDealing} and rows are compared by Euclidean distance.
 */
public class KNN {

	/** Per-column encoding of cell values as numbers, built once from the loaded data. */
	private final DataDealing transfer;
	/** Training rows; the last column of every row is its class label. */
	private final Matrix_2D<String> data;

	/**
	 * Loads the data set and builds the value encoding used for distances.
	 *
	 * @param path path of the comma-separated data file
	 * @throws IOException if the data file cannot be read
	 */
	public KNN(String path) throws IOException {
		data=new Matrix_2D<String>(ReadData.readDataFile(path));
		transfer=new DataDealing(data);
	}

	/**
	 * Classifies one example by majority vote among its {@code k} nearest training rows.
	 *
	 * @param testLine the example; its first {@code columns - 1} cells are the features, and each
	 *                 must be a value present in the same column of the data set loaded by the
	 *                 constructor
	 * @param k number of neighbours that vote; if it exceeds the number of training rows, every
	 *          row votes
	 * @return the winning class label, or {@code ""} when there are no training rows
	 * @throws IllegalArgumentException if {@code k < 1}
	 */
	public String knnClassification(ArrayList<String> testLine,int k) {
		if(k<1) throw new IllegalArgumentException("k must be at least 1, got "+k);
		final int row=data.getRowDimension();
		if(row==0) return "";
		final int col=data.getColDimension();
		// Max-heap on distance: the head is the farthest neighbour kept so far, i.e. the one to
		// evict once the heap grows past k. Double.compare honours the Comparator contract
		// (equal distances compare as 0).
		PriorityQueue<KNN_Node> que=new PriorityQueue<KNN_Node>(Math.min(k, row)+1,new Comparator<KNN_Node>() {
			@Override
			public int compare(KNN_Node o1, KNN_Node o2) {
				return Double.compare(o2.getDistanceWithTest(), o1.getDistanceWithTest());
			}
		});
		// Each row is visited exactly once, so no row can fill two neighbour slots.
		for(int i=0;i<row;++i) {
			que.add(new KNN_Node(i, data.get(i, col-1), calDistance(data.get(i), testLine, col-1)));
			if(que.size()>k) que.poll();
		}
		return majority(que);
	}

	/**
	 * Returns the most frequent class label among the given neighbours, draining {@code pq}.
	 * On a tie, the tied label that the HashMap iterates first wins.
	 */
	private String majority(PriorityQueue<KNN_Node> pq) {
		Map<String, Integer> count=new HashMap<String, Integer>();
		while(!pq.isEmpty()) {
			String label=pq.poll().getClassification();
			Integer seen=count.get(label);
			count.put(label, seen==null ? 1 : seen+1);
		}
		int n=0;
		String str="";
		for(Map.Entry<String, Integer> e : count.entrySet())
			if(e.getValue()>n) {
				n=e.getValue();
				str=e.getKey();
			}
		return str;
	}

	/** Euclidean distance between the first {@code len} encoded cells of two rows. */
	private double calDistance(ArrayList<String> a,ArrayList<String> b,int len) {
		double d=0.0;
		for(int i=0;i<len;++i) {
			double diff=transfer.getDouble(a.get(i), i)-transfer.getDouble(b.get(i), i);
			d+=diff*diff;
		}
		return Math.sqrt(d);
	}

	/**
	 * Resubstitution accuracy: classifies every row against the full data set (the row itself is
	 * included as a neighbour at distance 0) and returns the fraction labelled correctly.
	 *
	 * @param k number of neighbours
	 * @return fraction of rows classified correctly; NaN when the data set is empty
	 */
	public double reportModelSelf(int k) {
		final int row=data.getRowDimension();
		final int col=row==0 ? 0 : data.getColDimension();
		int count=0;
		for(int i=0;i<row;++i)
			if(knnClassification(data.get(i), k).equals(data.get(i, col-1))) ++count;
		return count/(double)row;
	}

	/**
	 * Hold-out accuracy: randomly moves a {@code 1 - p} share of the rows into a test set,
	 * classifies them against the remaining rows, and then puts them back into the data set
	 * (even if classification fails).
	 *
	 * @param k number of neighbours
	 * @param p fraction of rows kept for training, strictly between 0 and 1
	 * @return fraction of held-out rows classified correctly; NaN if no row was held out
	 * @throws IllegalArgumentException if {@code p} is not strictly between 0 and 1
	 */
	public double reportModel(int k,double p) {
		if(!(p>0.0&&p<1.0)) throw new IllegalArgumentException("training fraction must be in (0, 1), got "+p);
		final int rows=data.getRowDimension();
		final int col=rows==0 ? 0 : data.getColDimension();
		final int ntest=(int) (rows*(1-p));
		Matrix_2D<String> testData=new Matrix_2D<String>();
		for(int i=0;i<ntest;++i) testData.putLine(data.remove((int)(data.getRowDimension()*Math.random())));
		int count=0;
		try {
			for(int i=0;i<ntest;++i)
				if(knnClassification(testData.get(i), k).equals(testData.get(i, col-1))) ++count;
		} finally {
			// Restore the held-out rows so a failed evaluation does not shrink the model's data.
			while(testData.getRowDimension()>0) data.putLine(testData.remove(0));
		}
		return count/(double)ntest;
	}

	/** Prints self and 50% hold-out accuracy for k = 2..30 and their averages. */
	public static void main(String[] args) throws IOException {
		// Other data sets that can be used here: divorce.txt, StudentAcademicsPerformance.txt
		KNN knnTest=new KNN("AutismAdultDataPlus.txt");
		final int minK=2,maxK=30;
		final int nk=maxK-minK+1;
		double pp=0.0,p0;
		System.out.println("KNN模型准确率:");
		for(int k=minK;k<=maxK;++k) {
			p0=knnTest.reportModelSelf(k);
			System.out.println("k="+k+"\t"+p0);
			pp+=p0;
		}
		System.out.println("KNN模型准确率:"+(pp/nk));
		System.out.println("KNN模型_0.5准确率:");
		pp=0.0;
		for(int k=minK;k<=maxK;++k) {
			p0=knnTest.reportModel(k,0.5);
			System.out.println("k="+k+"\t"+p0);
			pp+=p0;
		}
		System.out.println("KNN模型_0.5准确率:"+(pp/nk));
	}

}


package tool;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class DataDealing {
	/**
	 * One map per column, from each distinct cell value in that column to its numeric code
	 * (1, 2, 3, ... in HashSet iteration order).
	 */
	private final List<Map<String, Double>> standardList;

	/**
	 * Builds the value-to-code maps for every column of {@code data}.
	 *
	 * @param data table whose column count is taken from its first row; an empty table yields no
	 *             columns
	 */
	public DataDealing(Matrix_2D<String> data) {
		final int rows=data.getRowDimension();
		final int cols=rows==0 ? 0 : data.getColDimension();
		standardList=new ArrayList<Map<String,Double>>(cols);
		for(int j=0;j<cols;++j) {
			// Collect the distinct values first; the code order follows the HashSet's iteration.
			Set<String> featureSet=new HashSet<String>();
			for(int i=0;i<rows;++i) featureSet.add(data.get(i, j));
			Map<String, Double> codes=new HashMap<String, Double>();
			int id=1;
			for(String key : featureSet) codes.put(key, (double)id++);
			standardList.add(codes);
		}
	}

	/**
	 * Returns the numeric code of {@code val} in column {@code index}.
	 *
	 * @throws IllegalArgumentException if {@code val} never occurred in that column
	 * @throws IndexOutOfBoundsException if {@code index} is not a valid column
	 */
	public double getDouble(String val,int index) {
		Double code=standardList.get(index).get(val);
		if(code==null)
			throw new IllegalArgumentException("value \""+val+"\" does not occur in column "+index);
		return code;
	}
}

package tool;

/**
 * Immutable record of one candidate neighbour: an identifier of the training example (KNN passes
 * its row index), that example's class label, and its distance to the test example.
 */
public class KNN_Node {
	private final int id;
	private final String classification;
	private final double distanceWithTest;

	/**
	 * @param id identifier of the training example
	 * @param classification class label of the training example
	 * @param distanceWithTest distance between the training example and the test example
	 */
	public KNN_Node(int id, String classification, double distanceWithTest) {
		this.id = id;
		this.classification = classification;
		this.distanceWithTest = distanceWithTest;
	}

	/** @return identifier of the training example */
	public int getId() {
		return id;
	}

	/** @return class label of the training example */
	public String getClassification() {
		return classification;
	}

	/** @return distance to the test example */
	public double getDistanceWithTest() {
		return distanceWithTest;
	}
}

package tool;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Minimal row-major 2-D table backed by a list of rows. Rows handed in are copied, so later
 * changes to the caller's lists do not affect the matrix.
 *
 * @param <T> cell type
 */
public class Matrix_2D<T> {
	/** Rows of the table; each row is a copy owned by this matrix. */
	ArrayList<ArrayList<T>> data;

	/** Creates an empty matrix with no rows. */
	public Matrix_2D() {
		data=new ArrayList<ArrayList<T>>();
	}

	/**
	 * Creates a matrix holding copies of the given rows.
	 *
	 * @param d rows to copy; must not be {@code null}
	 */
	public Matrix_2D(ArrayList<ArrayList<T>> d) {
		data=new ArrayList<ArrayList<T>>(d.size());
		for(ArrayList<T> val : d)
			this.putLine(val);
	}

	/** Appends a copy of {@code line} as a new last row. */
	public void putLine(ArrayList<T> line) {
		data.add(new ArrayList<T>(line));
	}

	/** @return number of rows */
	public int getRowDimension() {
		return data.size();
	}

	/** @return length of the first row, or 0 when the matrix has no rows */
	public int getColDimension() {
		return data.isEmpty() ? 0 : data.get(0).size();
	}

	/** @return row {@code i} itself (not a copy) */
	public ArrayList<T> get(int i) {
		return data.get(i);
	}

	/** @return the cell at row {@code i}, column {@code j} */
	public T get(int i,int j) {
		return data.get(i).get(j);
	}

	/** Removes and returns the cell at row {@code i}, column {@code j}, shortening that row. */
	public T remove(int i,int j) {
		return data.get(i).remove(j);
	}

	/** Removes and returns row {@code index}. */
	public ArrayList<T> remove(int index) {
		return data.remove(index);
	}

	/**
	 * Returns a copy of {@code original} with every element equal to {@code str} removed.
	 * The result is sized exactly, so it has no trailing {@code null}s, and an absent
	 * {@code str} yields a full copy instead of an exception.
	 */
	public static String[] subArray(String[] original,String str) {
		int kept=0;
		for(String s : original)
			if(!s.equals(str)) ++kept;
		String[] subArray=new String[kept];
		int k=0;
		for(String s : original)
			if(!s.equals(str)) subArray[k++]=s;
		return subArray;
	}

	/** @return a new mutable list with the same elements as {@code data} */
	public static ArrayList<String> copyArrayList(ArrayList<String> data) {
		return new ArrayList<String>(data);
	}

	/**
	 * Returns the most frequent label, or {@code ""} for an empty list. On a tie, the tied label
	 * that the HashMap iterates first wins.
	 */
	public static String majority(ArrayList<String> labels) {
		Map<String, Integer> labelCount=new HashMap<String, Integer>();
		for(String s : labels) {
			Integer seen=labelCount.get(s);
			labelCount.put(s, seen==null ? 1 : seen+1);
		}
		int count=-1;
		String t="";
		for(Map.Entry<String, Integer> e : labelCount.entrySet()) {
			if(e.getValue()>count) {
				count=e.getValue();
				t=e.getKey();
			}
		}
		return t;
	}
}

package tool;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;

public class ReadData {
	/**
	 * Reads a comma-separated file into rows of string cells, then shuffles the rows randomly.
	 * Blank lines are skipped so they do not become one-cell rows.
	 *
	 * @param path path of the data file
	 * @return the rows in random order, one list of cells per non-blank line
	 * @throws IOException if {@code path} is not an existing regular file, or reading fails
	 */
	public static ArrayList<ArrayList<String>> readDataFile(String path) throws IOException {
		File file=new File(path);
		if(!file.isFile())
			throw new IOException("data file not found: "+file.getAbsolutePath());
		ArrayList<ArrayList<String>> trainingSet=new ArrayList<ArrayList<String>>();
		BufferedReader reader=new BufferedReader(new InputStreamReader(new FileInputStream(file)));
		try {
			String str;
			while((str=reader.readLine())!=null) {
				if(str.trim().isEmpty()) continue;
				String[] tokens=str.split(",");
				ArrayList<String> cells=new ArrayList<String>(tokens.length);
				for(String token : tokens) cells.add(token);
				trainingSet.add(cells);
			}
		} finally {
			// Close the file even if reading fails part-way.
			reader.close();
		}
		// Randomise row order so later train/test splits are not biased by the file's ordering.
		Collections.shuffle(trainingSet);
		return trainingSet;
	}
}

发布了13 篇原创文章 · 获赞 0 · 访问量 1411

猜你喜欢

转载自：https://blog.csdn.net/qq_34262612/article/details/104131958