K最近邻算法:给定一些已经训练好的数据,输入一个新的测试数据点,计算距离此测试数据点最近的K个点的分类情况,哪个分类的类型占多数,则此测试点的分类与此相同。在这里,有的时候可以赋予不同的点不同的权重:近的点的权重大点,远的点自然就小点。
package knn;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import tool.DataDealing;
import tool.KNN_Node;
import tool.Matrix_2D;
import tool.ReadData;
/**
 * K-nearest-neighbour classifier over a table of string-valued samples.
 *
 * <p>Each row of the loaded table is one sample. The last column is used as
 * the class label and every other column as an attribute. Attribute strings
 * are converted to numbers through {@link DataDealing} and compared with the
 * Euclidean distance.
 */
public class KNN {
    /** Converts attribute strings to numeric codes; built once from the full table. */
    private DataDealing transfer;
    /** Labelled samples, one per row, label in the last column. */
    private Matrix_2D<String> data;

    /**
     * Loads the samples from {@code path} and prepares the string-to-number encoding.
     *
     * @param path location of the comma-separated data file
     * @throws IOException if the file is missing, is not a regular file, or cannot be read
     */
    public KNN(String path) throws IOException {
        ArrayList<ArrayList<String>> rows = ReadData.readDataFile(path);
        if (rows == null) {
            // readDataFile signals a missing file with null; fail with a clear
            // message instead of a NullPointerException further down.
            throw new IOException("Data file not found or not a regular file: " + path);
        }
        data = new Matrix_2D<String>(rows);
        transfer = new DataDealing(data);
    }

    /**
     * Classifies one sample by majority vote among its nearest rows in the table.
     *
     * <p>Every row is considered exactly once. When {@code k} is larger than the
     * number of rows, all rows vote.
     *
     * @param testLine sample to classify; only its first {@code columns - 1} entries are read
     * @param k number of neighbours that vote, at least 1
     * @return the label occurring most often among the selected neighbours
     * @throws IllegalArgumentException if {@code k} is smaller than 1
     */
    public String knnClassification(ArrayList<String> testLine, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        final int col = data.getColDimension();
        final int row = data.getRowDimension();
        // Max-heap on distance: the head is always the farthest of the kept neighbours,
        // so it is the one to evict when a closer row shows up.
        PriorityQueue<KNN_Node> nearest = new PriorityQueue<KNN_Node>(
                Math.max(1, Math.min(k, row)),
                new Comparator<KNN_Node>() {
                    @Override
                    public int compare(KNN_Node o1, KNN_Node o2) {
                        return Double.compare(o2.getDistanceWithTest(), o1.getDistanceWithTest());
                    }
                });
        for (int i = 0; i < row; ++i) {
            double dis = calDistance(data.get(i), testLine, col - 1);
            if (nearest.size() < k) {
                nearest.add(new KNN_Node(i, data.get(i, col - 1), dis));
            } else if (nearest.peek().getDistanceWithTest() > dis) {
                nearest.poll();
                nearest.add(new KNN_Node(i, data.get(i, col - 1), dis));
            }
        }
        return majority(nearest);
    }

    /**
     * Drains the queue and returns the label that occurs most often.
     * Ties are resolved by the iteration order of the counting map.
     * Returns the empty string for an empty queue.
     */
    private String majority(PriorityQueue<KNN_Node> pq) {
        Map<String, Integer> count = new HashMap<String, Integer>();
        while (!pq.isEmpty()) {
            String label = pq.poll().getClassification();
            Integer seen = count.get(label);
            count.put(label, seen == null ? 1 : seen + 1);
        }
        int best = 0;
        String winner = "";
        for (Map.Entry<String, Integer> entry : count.entrySet()) {
            if (entry.getValue() > best) {
                best = entry.getValue();
                winner = entry.getKey();
            }
        }
        return winner;
    }

    /** Euclidean distance between the encoded first {@code len} entries of {@code a} and {@code b}. */
    private double calDistance(ArrayList<String> a, ArrayList<String> b, int len) {
        double sumOfSquares = 0.0;
        for (int i = 0; i < len; ++i) {
            double diff = transfer.getDouble(a.get(i), i) - transfer.getDouble(b.get(i), i);
            sumOfSquares += diff * diff;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Measures accuracy by classifying every stored row against the full table.
     * Each row is itself part of the table it is classified against.
     *
     * @param k number of neighbours that vote
     * @return fraction of rows whose predicted label equals the stored label
     */
    public double reportModelSelf(int k) {
        final int row = data.getRowDimension();
        final int col = data.getColDimension();
        int count = 0;
        for (int i = 0; i < row; ++i) {
            if (knnClassification(data.get(i), k).equals(data.get(i, col - 1))) {
                ++count;
            }
        }
        return count / (double) row;
    }

    /**
     * Measures accuracy on a random hold-out split.
     *
     * <p>A fraction {@code 1 - p} of the rows is removed at random, classified
     * against the remaining rows, and then appended back to the table, so the
     * row order changes but no row is lost, even when classification fails.
     *
     * @param k number of neighbours that vote
     * @param p fraction of rows kept for training, between 0 and 1 inclusive
     * @return fraction of held-out rows classified correctly; NaN when no row is held out
     * @throws IllegalArgumentException if {@code p} is not within [0, 1]
     */
    public double reportModel(int k, double p) {
        if (!(p >= 0.0 && p <= 1.0)) {
            throw new IllegalArgumentException("p must be within [0, 1], got " + p);
        }
        final int col = data.getColDimension();
        final int ntest = (int) (data.getRowDimension() * (1 - p));
        Matrix_2D<String> testData = new Matrix_2D<String>();
        for (int i = 0; i < ntest; ++i) {
            testData.putLine(data.remove((int) (data.getRowDimension() * Math.random())));
        }
        int count = 0;
        try {
            for (int i = 0; i < ntest; ++i) {
                if (knnClassification(testData.get(i), k).equals(testData.get(i, col - 1))) {
                    ++count;
                }
            }
        } finally {
            // Always give the held-out rows back so a failure does not shrink the table.
            for (int i = 0; i < testData.getRowDimension(); ++i) {
                data.putLine(testData.get(i));
            }
        }
        return count / (double) ntest;
    }

    /**
     * Prints the self-test accuracy and the 50% hold-out accuracy for k = 2..30,
     * followed by the mean of each series.
     */
    public static void main(String[] args) throws IOException {
        // Other data sets used with this program:
        // divorce.txt, AutismAdultDataPlus.txt, StudentAcademicsPerformance.txt
        KNN knnTest = new KNN("AutismAdultDataPlus.txt");
        final int kMin = 2;
        final int kMax = 30;
        final int runs = kMax - kMin + 1;
        double pp = 0.0, p0;
        System.out.println("KNN模型准确率:");
        for (int k = kMin; k <= kMax; ++k) {
            p0 = knnTest.reportModelSelf(k);
            System.out.println("k=" + k + "\t" + p0);
            pp += p0;
        }
        System.out.println("KNN模型准确率:" + (pp / runs));
        System.out.println("KNN模型_0.5准确率:");
        pp = 0.0;
        for (int k = kMin; k <= kMax; ++k) {
            p0 = knnTest.reportModel(k, 0.5);
            System.out.println("k=" + k + "\t" + p0);
            pp += p0;
        }
        System.out.println("KNN模型_0.5准确率:" + (pp / runs));
    }
}
package tool;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Encodes the string cells of a table as numbers, with one lookup table per column.
 *
 * <p>For each column the distinct cell values are gathered in a {@link HashSet}
 * and numbered 1, 2, 3, ... in that set's iteration order, so the numbers
 * identify values but carry no ordering meaning.
 */
public class DataDealing {
    /** Element j maps every distinct string found in column j to its numeric code. */
    private List<Map<String, Double>> standardList;

    /**
     * Builds the per-column lookup tables from all rows of {@code data}.
     *
     * @param data table whose cells are to be encoded
     */
    public DataDealing(Matrix_2D<String> data) {
        final int columns = data.getColDimension();
        standardList = new ArrayList<Map<String, Double>>(columns);
        int column = 0;
        while (column < columns) {
            // Collect the distinct values first; their set order fixes the numbering.
            Set<String> distinct = new HashSet<String>();
            final int rows = data.getRowDimension();
            for (int r = 0; r < rows; ++r) {
                distinct.add(data.get(r, column));
            }
            Map<String, Double> codes = new HashMap<String, Double>();
            int nextCode = 1;
            for (String value : distinct) {
                codes.put(value, (double) nextCode);
                nextCode++;
            }
            standardList.add(codes);
            column++;
        }
    }

    /**
     * Looks up the numeric code of {@code val} in column {@code index}.
     *
     * @param val cell value to translate
     * @param index zero-based column the value belongs to
     * @return the code assigned to {@code val} for that column
     * @throws NullPointerException if {@code val} was not present in that column
     *         when this object was built
     */
    public double getDouble(String val, int index) {
        Map<String, Double> codes = standardList.get(index);
        Double code = codes.get(val);
        return code.doubleValue();
    }
}
package tool;
/**
 * Value holder pairing an identifier and a class label with a distance
 * to some test sample. All three values are fixed at construction.
 */
public class KNN_Node {
    /** Distance between this node and the test sample. */
    private final double distanceWithTest;
    /** Class label carried by this node. */
    private final String classification;
    /** Identifier supplied by the creator of the node. */
    private final int id;

    /**
     * @param id identifier of the node
     * @param classification class label of the node
     * @param distanceWithTest distance from the node to the test sample
     */
    public KNN_Node(int id, String classification, double distanceWithTest) {
        this.distanceWithTest = distanceWithTest;
        this.classification = classification;
        this.id = id;
    }

    /** @return the identifier given at construction */
    public int getId() {
        return this.id;
    }

    /** @return the class label given at construction */
    public String getClassification() {
        return this.classification;
    }

    /** @return the distance given at construction */
    public double getDistanceWithTest() {
        return this.distanceWithTest;
    }
}
package tool;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
 * Small growable two-dimensional table stored as a list of row lists,
 * plus a few static helpers for string arrays and lists.
 *
 * @param <T> type of the cells
 */
public class Matrix_2D<T> {
    /** Row storage; each inner list is one row. */
    ArrayList<ArrayList<T>> data;

    /** Creates a table with no rows. */
    public Matrix_2D() {
        data = new ArrayList<ArrayList<T>>();
    }

    /**
     * Creates a table holding a copy of every row of {@code d}.
     *
     * @param d rows to copy; the row lists themselves are not retained
     */
    public Matrix_2D(ArrayList<ArrayList<T>> d) {
        data = new ArrayList<ArrayList<T>>();
        for (ArrayList<T> val : d) {
            this.putLine(val);
        }
    }

    /**
     * Appends a copy of {@code line} as the new last row.
     *
     * @param line cells of the row to append
     */
    public void putLine(ArrayList<T> line) {
        data.add(new ArrayList<T>(line));
    }

    /** @return the number of rows */
    public int getRowDimension() {
        return data.size();
    }

    /**
     * @return the number of cells in the first row
     * @throws IndexOutOfBoundsException if the table has no rows
     */
    public int getColDimension() {
        return data.get(0).size();
    }

    /** @return the stored list for row {@code i} (not a copy) */
    public ArrayList<T> get(int i) {
        return data.get(i);
    }

    /** @return the cell at row {@code i}, column {@code j} */
    public T get(int i, int j) {
        return data.get(i).get(j);
    }

    /**
     * Removes the cell at row {@code i}, column {@code j}, shortening that row by one.
     *
     * @return the removed cell
     */
    public T remove(int i, int j) {
        return data.get(i).remove(j);
    }

    /**
     * Removes row {@code index} from the table.
     *
     * @return the removed row
     */
    public ArrayList<T> remove(int index) {
        return data.remove(index);
    }

    /**
     * Returns the elements of {@code original} that are not equal to {@code str},
     * in their original order.
     *
     * <p>The result is sized to the number of kept elements, so it is correct
     * whether {@code str} occurs once, several times, or not at all.
     *
     * @param original source array; its elements must not be null
     * @param str value to leave out
     * @return a new array without any occurrence of {@code str}
     */
    public static String[] subArray(String[] original, String str) {
        ArrayList<String> kept = new ArrayList<String>(original.length);
        for (String s : original) {
            if (!s.equals(str)) {
                kept.add(s);
            }
        }
        return kept.toArray(new String[kept.size()]);
    }

    /** @return a new list with the same elements as {@code data}, in the same order */
    public static ArrayList<String> copyArrayList(ArrayList<String> data) {
        return new ArrayList<String>(data);
    }

    /**
     * Returns the label that occurs most often in {@code labels}.
     * Ties are resolved by the iteration order of the counting map.
     *
     * @param labels labels to count
     * @return the most frequent label, or the empty string when {@code labels} is empty
     */
    public static String majority(ArrayList<String> labels) {
        Map<String, Integer> labelCount = new HashMap<String, Integer>();
        for (String s : labels) {
            Integer seen = labelCount.get(s);
            labelCount.put(s, seen == null ? 1 : seen + 1);
        }
        int count = -1;
        String t = "";
        for (Map.Entry<String, Integer> entry : labelCount.entrySet()) {
            if (entry.getValue() > count) {
                count = entry.getValue();
                t = entry.getKey();
            }
        }
        return t;
    }
}
package tool;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
/** Loader for comma-separated data files. */
public class ReadData {
    /**
     * Reads a comma-separated text file into a list of rows and returns the rows
     * in random order.
     *
     * <p>Each line is split on {@code ","}; the pieces become the cells of one row.
     * The file is decoded with the platform default charset.
     *
     * @param path location of the file to read
     * @return the rows in shuffled order, or {@code null} when {@code path} does not
     *         name an existing regular file (the absolute path is printed to
     *         standard output in that case)
     * @throws IOException if reading the file fails
     */
    public static ArrayList<ArrayList<String>> readDataFile(String path) throws IOException {
        File file = new File(path);
        if (!file.exists() || !file.isFile()) {
            System.out.println(file.getAbsolutePath());
            return null;
        }
        ArrayList<ArrayList<String>> trainingSet = new ArrayList<ArrayList<String>>();
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file)));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(",");
                ArrayList<String> row = new ArrayList<String>(fields.length);
                for (String field : fields) {
                    row.add(field);
                }
                trainingSet.add(row);
            }
        } finally {
            // Close even when readLine fails so the file handle is not leaked.
            reader.close();
        }
        // Shuffle the rows in place: swap each position with a uniformly chosen
        // position at or before it, walking from the end to the front.
        for (int i = trainingSet.size() - 1; i > 0; --i) {
            int j = (int) ((i + 1) * Math.random());
            ArrayList<String> tmp = trainingSet.get(i);
            trainingSet.set(i, trainingSet.get(j));
            trainingSet.set(j, tmp);
        }
        return trainingSet;
    }
}