【Javaオブジェクト指向】Knn実験

背景紹介

近接アルゴリズム、または K-最近傍 (KNN、K-NearestNeighbor) 分類アルゴリズムは、データ マイニング分類テクノロジの最も単純な方法の 1 つです。いわゆる K 最近傍とは、K 個の最近傍を意味します。これは、各サンプルがその最も近い K 個の近傍によって表現できることを意味します。最近傍アルゴリズムは、データ セット内の各レコードを分類する方法です。

画像はインターネットから来ました

KNN アルゴリズムの中心的な考え方は、特徴空間内のサンプルの K 個の最近傍サンプルのほとんどが特定のカテゴリに属している場合、そのサンプルもこのカテゴリに属し、このカテゴリ内のサンプルの特性を持つということです。分類の決定を行う際、この方法では、最も近い 1 つまたは複数のサンプルのカテゴリに従って、分割するサンプルのカテゴリを決定するだけです。KNN メソッドは、カテゴリの決定を行う際に、非常に少数の隣接するサンプルのみに関連します。KNN 法は、クラス領域を識別して属するカテゴリを決定する方法ではなく、主に周囲の限られたサンプルに依存するため、KNN 法は、クラス領域が交差するときにサンプル セットを分割する他の方法よりも正確です。フィット感を高めるために、さらに重なり合うこともあります。

画像はインターネットから来ました

研究室環境

ASUS VivoBook + Windows10 + IntelliJ IDEA 2021.3.2 (コミュニティ版) + JDK17

実験内容

K 最近傍アルゴリズムは、データを分類するための最も単純かつ効果的なアルゴリズムであり、インスタンスベースの学習方法を使用します。簡単に言うと、異なるサンプル間の距離を測定することで分類を実行します。その動作原理は次のとおりです。トレーニング サンプル セットとも呼ばれるサンプル データ セットがあり、サンプル セット内の各データにはラベルが付いています。つまり、各データが属する分類がわかっています。ラベルなしで新しいデータを入力した後、新しいデータの各特徴がサンプル セット内のデータの対応する特徴と比較され、アルゴリズムによってサンプル セット内で最も類似したデータ (最近傍) の分類ラベルが抽出されます。一般に、サンプル データ セット内で最も類似した上位 K 個のデータのみを選択します。これが、K 最近傍アルゴリズムの K のソースになります。最後に、K 個の最も類似したデータの中で最も多く出現したカテゴリが選択されます。

0 から 9 までの数字を認識できる K 最近傍分類器を使用した手書き認識システムを構築するには、クラス KnnNumber を設計および実装する必要があります。認識される数字は同じ色とサイズになるように処理されています: 32テキスト形式で表されるピクセル* 32 ピクセルの白黒画像 (0/1 バイナリ画像)。データセットはdigits.zip内にあります。trainingDigits ディレクトリにはトレーニング サンプル セットが含まれており、これには約 2000 のサンプル データが含まれ、各データのファイル名はそのラベル (0 ~ 9 の特定の数字) を示し、各番号には約 200 のサンプル データが含まれ、ディレクトリ testDigits には約 900 のテストが含まれます。データ。アルゴリズムの各ステップを実現するために、KnnNumber のデータ メンバーとメンバー メソッドを合理的に設計してください。最も高い分類精度を持つ K 値を見つけるには、K を [3,9] の間の整数として設定します。

データセットのアドレス: https://github.com/YourHealer/KNN-DataSet.git

(1) 実験アイデア

① クラス KnnNumber を定義、トレーニング セットのデータを格納する trainData を定義、テスト セットのデータを格納する testData を定義、トレーニング セットの要素に対応するタグを格納する trainTag を定義、トレーニング セットの長さを格納する trainCount を定義します。 。

②32*32のバイナリ画像では情報が得にくいため、ファイル内のバイナリ行列を1*1034のベクトルに変換すると判定が簡単になります。この機能を実現するために、メソッドconcertLongVectorを定義します。

③ 私たちの目的は予測精度を比較することなので、ファイル名を通じてバイナリ画像の実際のラベルを取得する必要があります。この機能を実現するために getTag メソッドを定義します。

④ getTrainData メソッドを定義し、トレーニング セット ファイルのリストを読み込み、各ファイル内のバイナリ配列をベクトルに変換し、配列 trainData と trainTag に埋めます。

⑤ 各ベクトルのユークリッド距離を計算して昇順に並べるメソッド knn を定義し、knn アルゴリズムの異なる n に従って結果をフィルタリングします。

⑥テストセットの予測結果を検証し、タグの予測を計算して出力するメソッドtagJudgeを定義します。

⑦最後に、インスタンスを作成し、テストセットを入力し、対応する結果を取得します。

(2) 実験ソースコード

Knn番号:

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;

public class KnnNumber {
    public int[][] trainData;  //存放训练数据
    public int[] trainTag;     //存放训练数据对应标签
    public int trainCount;     //记录训练数据量

    //将32x32的二进制图像转换为1x1024的值向量并返回
    public int[] convertLongVector(String path) throws IOException
    {
        int[] longVector = new int[1024];
        BufferedReader br = new BufferedReader(new FileReader(path));
        char[] str = new char[34];
        //对原数组的每一行操作
        for(int i=0;i<32;i++)
        {
            int ret = br.read(str,0,34);
            //对原数组的每一列进行操作
            for(int j=0;j<32;j++){
                longVector[i*32+j]= str[j]-48;
            }
        }
        br.close();

        return longVector;
    }

    //根据文件名返回二进制值图像的标签
    public int getTag(String path)
    {
        return(Integer.parseInt(path.split("_")[0]));
    }

    public void getTrainData(String filename) throws IOException //建立训练数据和对应的分类标签
    {
        String[] trainFile=new File(filename).list();  //获取文件的子文件名
        //假定文件非空
        if (trainFile != null) {
            trainCount =trainFile.length;
            trainData = new int[trainCount][1024];        //存储测试数据
            trainTag = new int[trainCount];             //存储测试数据标签
            for(int i = 0; i< trainCount; i++)
            {
                trainData[i] = convertLongVector(filename+"\\"+trainFile[i]);
                trainTag[i] = getTag(trainFile[i]);
            }
        }

    }

    //根据算法原理预测二进制值图像的标签
    public int knn(int[] testData,int k)
    {
        double[][] dis = new double[trainCount][2];
        int[][] testVector = new int[trainCount][1024];
        int[][] resMatrix = new int[trainCount][1024];
        int[] tagCnt = new int[10];
        int maxTimes = 0;
        int maxTag = 0;

        for(int i = 0; i< trainCount; i++){
            //向量扩展
            System.arraycopy(testData, 0, testVector[i], 0, 1024);
        }

        //计算每个向量的欧氏距离
        for(int i = 0; i< trainCount; i++)
        {
            for(int j = 0; j < 1024; j++){
                resMatrix[i][j]=(int)Math.pow(testVector[i][j]- trainData[i][j], 2);
            }
        }

        for(int i = 0; i< trainCount; i++)
        {
            dis[i][0]=0;
            dis[i][1]=i;
            for(int j = 0;j< 1024; j++){
                dis[i][0]+=resMatrix[i][j];
            }
            dis[i][0]=Math.sqrt(dis[i][0]);
        }

        //对欧氏距离数组进行升序排序
        Arrays.sort(dis,new Comparator<double[]>(){
            @Override
            public int compare(double[] a,double[] b) {
                return Double.compare(a[0], b[0]);
            }
        });

        //根据参数k确定筛选数量
        for(int i = 0; i<k; i++)
        {
            tagCnt[trainTag[(int)dis[i][1]]]++;
        }

        //找到该二进制图像的最可能的预测标签
        for(int i=0;i<10;i++){
            if(tagCnt[i]>maxTimes) //找出出现次数最多的标签
            {
                maxTimes=tagCnt[i];
                maxTag=i;
            }
        }
        return maxTag;
    }

    //识别二进制图像标签,得出准确率
    public void tagJudge(String path) throws IOException
    {
        String[] testFile = new File(path).list();
        int[] testData;
        int testCount = 0;
        int trueLabel;
        int predictLabel;
        HashMap<Integer, Double> map = new HashMap<Integer, Double>();

        if (testFile != null) {
            testCount = testFile.length;
        }

        for(int k = 3; k<= 9; k++)
        {
            int wrongCnt=0;
            for(int i = 0; i< testCount; i++)
            {
                testData = convertLongVector(path+"\\"+testFile[i]);
                trueLabel = getTag(testFile[i]);
                predictLabel = knn(testData,k);

                if(predictLabel!=trueLabel){
                    wrongCnt++;
                }

                System.out.println("第" +(i+1)+ "组数据真实值为:"+trueLabel+",真实值为:" +trueLabel + "。");
            }

            double accuracy=(1-(double)wrongCnt/testCount)*100;
            System.out.println("k值为:" +k+ "时,准确率约为:"+String.format("%.4f",accuracy)+"%。");
            map.put(k, accuracy);
        }

        List<HashMap.Entry<Integer, Double> > list = new ArrayList<Map.Entry<Integer,Double>>(map.entrySet());
        list.sort(new Comparator<HashMap.Entry<Integer, Double>>() {
            @Override
            public int compare(HashMap.Entry<Integer, Double> a, HashMap.Entry<Integer, Double> b) {
                return b.getValue().compareTo(a.getValue());
            }
        });
        System.out.println();
        System.out.println("所有k值对应的准确率如下:");

        for(HashMap.Entry<Integer, Double> ele:list)
        {
            System.out.println("k:"+ele.getKey()+"  "+"准确率约为:"+String.format("%.4f",ele.getValue())+ "%");
        }
    }

}

KnnNumberTest:

import java.io.IOException;

public class KnnNumberTest {

    public static void main(String[] args) throws IOException {
        //创建实例进行测试
        KnnNumber knn = new KnnNumber();
        knn.getTrainData("C:\\Users\\ASUS\\Desktop\\FolderOne\\digits\\trainingDigits");   //建立训练数据和对应的分类标签,我把测试集与训练集文件放在桌面的FoldOne文件夹的digits子文件夹里
        knn.tagJudge("C:\\Users\\ASUS\\Desktop\\FolderOne\\digits\\testDigits");              //识别测试数据的标签,得出准确率
    }
}

(4) 実験体験

この質問には 3 つの困難があります。

最初の困難は、トレーニング セットとテスト セットの処理にあります。IO ストリームを使用してデータの基本的な処理を実行する方法。

2 番目の困難は、ファイル コンテンツの処理、実際のラベルの取得方法、可能なラベルの予測方法、最も類似した 2 進数の決定方法、およびベクトル距離の計算方法にあります。

3 番目の困難は、さまざまな状況における k の結果の計算と結果のスクリーニングにあります。

この学習を通じて、IO フローを使用して実際的な問題を解決する方法を学び、多くの恩恵を受けました。

おすすめ

転載: blog.csdn.net/ayaishere_/article/details/128711920