用 Java 实现 KNN（K 近邻）算法

1. 训练数据

//格式：类别,x,y（第一列为类别标签，后两列为二维坐标；本行仅为说明，不属于 knn.txt 的内容）
1,1.0,1.0
1,1.1,1.2
1,1.2,1.0
1,1.6,1.5
1,1.3,1.7
1,2.0,2.1
1,2.0,2.2
1,2.3,2.3
2,9.0,9.0
2,9.1,9.2
2,9.2,9.0
2,10.6,10.5
2,10.3,10.7
2,9.6,9.1
2,9.4,10.4
2,10.3,10.3
3,10.0,1.0
3,10.1,1.2
3,10.2,1.0
3,10.6,1.5
3,10.3,1.7
3,10.0,2.1
3,10.0,2.2
3,10.3,2.3

2. 源代码（省略了 import 语句，需导入 java.io.* 与 java.util.*）

public class KNN_Serial {

    /** Training file read by {@link #knn(double, double, int)}; kept for backward compatibility. */
    private static final String DEFAULT_TRAINING_FILE = "C:\\Users\\wl105\\Desktop\\knn.txt";

    /**
     * Computes the Euclidean distance between the points (a1, a2) and (b1, b2).
     */
    private static double O_distance(double a1, double a2, double b1, double b2) {
        double dx = a1 - b1;
        double dy = a2 - b2;
        return Math.sqrt(dx * dx + dy * dy);
    }

    /**
     * Reads a training set where every data line has the form {@code type,x,y}
     * (class label first, then the two coordinates).
     * Blank lines and lines starting with {@code //} are skipped.
     *
     * @param filename path of the training file
     * @return the samples in file order
     * @throws IOException if the file cannot be read or a data line has fewer than three fields
     * @throws NumberFormatException if a field is not a valid number
     */
    public static List<XY> readfile(String filename) throws IOException {
        List<XY> list = new ArrayList<XY>();
        // try-with-resources so the reader is closed even if parsing fails.
        try (BufferedReader bufferedReader = new BufferedReader(new FileReader(filename))) {
            String line;
            int lineNumber = 0;
            while ((line = bufferedReader.readLine()) != null) {
                lineNumber++;
                String trimmed = line.trim();
                if (trimmed.isEmpty() || trimmed.startsWith("//")) {
                    continue;
                }
                String[] fields = trimmed.split(",");
                if (fields.length < 3) {
                    throw new IOException(filename + ":" + lineNumber
                            + ": expected 'type,x,y' but got '" + line + "'");
                }
                // Integer.parseInt does not tolerate surrounding whitespace, so trim explicitly.
                list.add(new XY(Double.parseDouble(fields[1].trim()),
                        Double.parseDouble(fields[2].trim()),
                        Integer.parseInt(fields[0].trim())));
            }
        }
        return list;
    }

    /**
     * Stores in each sample's {@code distance} field its Euclidean distance to (x, y).
     * The samples are mutated in place.
     *
     * @return the same list instance that was passed in
     */
    public static List<XY> alldistance(double x, double y, List<XY> xy) {
        for (XY o : xy) {
            o.distance = O_distance(x, y, o.x, o.y);
        }
        return xy;
    }

    /**
     * Returns which of three vote counts is largest, as 1, 2 or 3.
     * Ties are resolved toward the higher label: for example, 1 and 2 tied above 3 gives 2,
     * and an all-equal tie gives 3.
     */
    public static int threemax(int type1, int type2, int type3) {
        if (type1 > type2) {
            return type1 > type3 ? 1 : 3;
        }
        return type2 > type3 ? 2 : 3;
    }

    /**
     * Classifies (x, y) by majority vote among its k nearest neighbours.
     * The training set is read from the default training file.
     *
     * @see #knn(double, double, int, String)
     */
    public int knn(double x, double y, int k) throws IOException {
        return knn(x, y, k, DEFAULT_TRAINING_FILE);
    }

    /**
     * Classifies (x, y) by majority vote among its k nearest neighbours in the given training file.
     * Only labels 1, 2 and 3 are counted. If k exceeds the number of samples, all samples vote.
     *
     * @param k number of neighbours; must be at least 1
     * @return the winning label (1, 2 or 3); ties are resolved as in {@link #threemax}
     * @throws IllegalArgumentException if {@code k < 1}
     * @throws IllegalStateException if the training file contains no samples
     * @throws IOException if the training file cannot be read
     */
    public int knn(double x, double y, int k, String filename) throws IOException {
        if (k < 1) {
            throw new IllegalArgumentException("k must be >= 1, got " + k);
        }
        List<XY> list = alldistance(x, y, readfile(filename));
        if (list.isEmpty()) {
            throw new IllegalStateException("no training samples in " + filename);
        }
        Collections.sort(list);
        int neighbours = Math.min(k, list.size());
        int type1 = 0, type2 = 0, type3 = 0;
        // Advance i (not k): the original incremented k, which never terminated properly.
        for (int i = 0; i < neighbours; i++) {
            int type = list.get(i).type;
            if (type == 1) {
                type1++;
            } else if (type == 2) {
                type2++;
            } else if (type == 3) {
                type3++;
            }
        }
        return threemax(type1, type2, type3);
    }
}



class XY implements Comparable<XY>{
    public double x,y,distance;
    public int type=0;
    public XY(double x,double y )
    {
        this.x=x;
        this.y=y;
    }
    public XY(double x,double y,int type)
    {
        this.x=x;
        this.y=y;
        this.type=type;
    }
    public int findType()
    {
        return type;
    }
    public void setType(int type)
    {
        this.type=type;
    }


    @Override
    public int compareTo(XY o) {
        if (this.distance < o.distance) {
            return -1;
        }else if (this.distance  > o.distance) {
            return 1;
        }
        return 0;
    }
}

猜你喜欢

转载自：https://blog.csdn.net/weixin_39216383/article/details/80628003
今日推荐