【Java面向对象】Knn实验

背景介绍

邻近算法,或者说K最邻近(KNN,K-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。近邻算法就是将数据集合中每一个记录进行分类的方法。

图片源自网络

KNN算法的核心思想是,如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN方法在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。

图片源自网络

实验环境

ASUS VivoBook + Windows10 + IntelliJ IDEA 2021.3.2 (Community Edition) + JDK17

实验内容

K近邻算法是分类数据最简单有效的算法,它采用基于实例的学习方法。简单地说,它采用测量不同样本之间距离的方法进行分类。它的工作原理是:存在一个样本数据集合,也称为训练样本集,并且样本集中的每个数据都有标签,即我们知道每个数据所属的分类。输入没有标签的新数据之后,将新数据的每个特征与样本集中数据的对应特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前K个最相似的数据,这就是K近邻算法中K的出处。最后,选择K个最相似数据中出现次数最多的分类。

要求设计和实现类KnnNumber,构造一个使用K近邻分类器的手写识别系统,该系统可以识别数字0到9.需要识别的数字已经处理成具有相同的色彩和大小:使用文本格式表示的32像素*32像素黑白图像(0/1二值图像)。数据集在digits.zip中。其中目录trainingDigits包含了训练样本集,其中包含了大约2000个样本数据,每个数据的文件名表明了它的标签(0~9的某个数字),每个数字大约有200个样本数据;目录testDigits中包含了大约900个测试数据。请合理设计KnnNumber的数据成员和成员方法,以实现算法的各个步骤。将K取值为[3,9]之间的一个整数,找出分类准确率最高的K值。

数据集地址:https://github.com/YourHealer/KNN-DataSet.git

(1)实验思路

①定义类KnnNumber,定义trainData存放训练集数据,定义testData存放测试集数据,定义trainTag存放训练集元素对应的标签,定义trainCount存放训练集长度。

②由于32*32的二进制图像难以获取信息,将文件中的二进制矩阵转化为1*1034的向量就可以简化判断。我们定义方法concertLongVector实现这一功能。

③由于我们的目标在于比较预测准确率,我们需要通过文件名获取二进制图像的真实标签。我们定义方法getTag实现这一功能。

④我们定义方法getTrainData,将训练集文件列表读入,将各文件中的二进制数组转化为向量填充到数组trainData和trainTag中。

⑤我们定义方法knn计算各向量的欧氏距离并升序排列,根据knn算法中不同的n对结果进行筛选。

⑥我们定义方法tagJudge核实测试集的预测结果,对标签的预测情况进行运算并打印出来。

⑦最后我们创建实例,将测试集输入,得到相应结果。

(2)实验源码

KnnNumber:

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;

public class KnnNumber {
    public int[][] trainData;  //存放训练数据
    public int[] trainTag;     //存放训练数据对应标签
    public int trainCount;     //记录训练数据量

    //将32x32的二进制图像转换为1x1024的值向量并返回
    public int[] convertLongVector(String path) throws IOException
    {
        int[] longVector = new int[1024];
        BufferedReader br = new BufferedReader(new FileReader(path));
        char[] str = new char[34];
        //对原数组的每一行操作
        for(int i=0;i<32;i++)
        {
            int ret = br.read(str,0,34);
            //对原数组的每一列进行操作
            for(int j=0;j<32;j++){
                longVector[i*32+j]= str[j]-48;
            }
        }
        br.close();

        return longVector;
    }

    //根据文件名返回二进制值图像的标签
    public int getTag(String path)
    {
        return(Integer.parseInt(path.split("_")[0]));
    }

    public void getTrainData(String filename) throws IOException //建立训练数据和对应的分类标签
    {
        String[] trainFile=new File(filename).list();  //获取文件的子文件名
        //假定文件非空
        if (trainFile != null) {
            trainCount =trainFile.length;
            trainData = new int[trainCount][1024];        //存储测试数据
            trainTag = new int[trainCount];             //存储测试数据标签
            for(int i = 0; i< trainCount; i++)
            {
                trainData[i] = convertLongVector(filename+"\\"+trainFile[i]);
                trainTag[i] = getTag(trainFile[i]);
            }
        }

    }

    //根据算法原理预测二进制值图像的标签
    public int knn(int[] testData,int k)
    {
        double[][] dis = new double[trainCount][2];
        int[][] testVector = new int[trainCount][1024];
        int[][] resMatrix = new int[trainCount][1024];
        int[] tagCnt = new int[10];
        int maxTimes = 0;
        int maxTag = 0;

        for(int i = 0; i< trainCount; i++){
            //向量扩展
            System.arraycopy(testData, 0, testVector[i], 0, 1024);
        }

        //计算每个向量的欧氏距离
        for(int i = 0; i< trainCount; i++)
        {
            for(int j = 0; j < 1024; j++){
                resMatrix[i][j]=(int)Math.pow(testVector[i][j]- trainData[i][j], 2);
            }
        }

        for(int i = 0; i< trainCount; i++)
        {
            dis[i][0]=0;
            dis[i][1]=i;
            for(int j = 0;j< 1024; j++){
                dis[i][0]+=resMatrix[i][j];
            }
            dis[i][0]=Math.sqrt(dis[i][0]);
        }

        //对欧氏距离数组进行升序排序
        Arrays.sort(dis,new Comparator<double[]>(){
            @Override
            public int compare(double[] a,double[] b) {
                return Double.compare(a[0], b[0]);
            }
        });

        //根据参数k确定筛选数量
        for(int i = 0; i<k; i++)
        {
            tagCnt[trainTag[(int)dis[i][1]]]++;
        }

        //找到该二进制图像的最可能的预测标签
        for(int i=0;i<10;i++){
            if(tagCnt[i]>maxTimes) //找出出现次数最多的标签
            {
                maxTimes=tagCnt[i];
                maxTag=i;
            }
        }
        return maxTag;
    }

    //识别二进制图像标签,得出准确率
    public void tagJudge(String path) throws IOException
    {
        String[] testFile = new File(path).list();
        int[] testData;
        int testCount = 0;
        int trueLabel;
        int predictLabel;
        HashMap<Integer, Double> map = new HashMap<Integer, Double>();

        if (testFile != null) {
            testCount = testFile.length;
        }

        for(int k = 3; k<= 9; k++)
        {
            int wrongCnt=0;
            for(int i = 0; i< testCount; i++)
            {
                testData = convertLongVector(path+"\\"+testFile[i]);
                trueLabel = getTag(testFile[i]);
                predictLabel = knn(testData,k);

                if(predictLabel!=trueLabel){
                    wrongCnt++;
                }

                System.out.println("第" +(i+1)+ "组数据真实值为:"+trueLabel+",真实值为:" +trueLabel + "。");
            }

            double accuracy=(1-(double)wrongCnt/testCount)*100;
            System.out.println("k值为:" +k+ "时,准确率约为:"+String.format("%.4f",accuracy)+"%。");
            map.put(k, accuracy);
        }

        List<HashMap.Entry<Integer, Double> > list = new ArrayList<Map.Entry<Integer,Double>>(map.entrySet());
        list.sort(new Comparator<HashMap.Entry<Integer, Double>>() {
            @Override
            public int compare(HashMap.Entry<Integer, Double> a, HashMap.Entry<Integer, Double> b) {
                return b.getValue().compareTo(a.getValue());
            }
        });
        System.out.println();
        System.out.println("所有k值对应的准确率如下:");

        for(HashMap.Entry<Integer, Double> ele:list)
        {
            System.out.println("k:"+ele.getKey()+"  "+"准确率约为:"+String.format("%.4f",ele.getValue())+ "%");
        }
    }

}

KnnNumberTest:

import java.io.IOException;

public class KnnNumberTest {

    public static void main(String[] args) throws IOException {
        //创建实例进行测试
        KnnNumber knn = new KnnNumber();
        knn.getTrainData("C:\\Users\\ASUS\\Desktop\\FolderOne\\digits\\trainingDigits");   //建立训练数据和对应的分类标签,我把测试集与训练集文件放在桌面的FoldOne文件夹的digits子文件夹里
        knn.tagJudge("C:\\Users\\ASUS\\Desktop\\FolderOne\\digits\\testDigits");              //识别测试数据的标签,得出准确率
    }
}

(4)实验心得

此题中共有三个难点。

第一个难点在于训练集与测试集的处理。如何利用IO流对数据进行基本处理。

第二个难点在于文件内容处理,如何获取真实标签、如何预测可能的标签、如何判定二进制数字最为相像、如何计算向量距离。

第三个难点在于对不同情况下的k进行结果运算与结果筛选。

通过本次学习,我切实了解到如何使用IO流解决实际问题,受益匪浅。

猜你喜欢

转载自blog.csdn.net/ayaishere_/article/details/128711920
今日推荐