データ マイニング Java - KNN アルゴリズムの実装

1. KNN アルゴリズムの予備知識

k 近傍(kNN、k-NearestNeighbor) は、トレーニング セット内の入力データ ポイントから k 個の最近傍を選択し、k 近傍の中で最も多く出現したカテゴリ (最大投票ルール) をそのカテゴリとして使用します。データポイント。

分類はデータマイニングにおいて非常に重要なタスクです。分類の目的は、データベース内のデータ項目を特定のタイプの特定のカテゴリにマッピングできる分類関数または分類モデル (分類子とも呼ばれる) を学習することです。分類は予測に使用できます。予測の目的は、過去のデータ記録から特定のデータの傾向の説明を自動的に導き出し、将来のデータを予測できるようにすることです。統計で一般的に使用される予測方法は回帰です。データマイニングにおける分類と統計における回帰法は、相互に関連しているものの、異なる概念のペアです。一般に、分類の出力は離散カテゴリ値ですが、回帰の出力は連続値です。

類似性: データベース D={t1,t2,…,tn} とクラスのセット C={C1,C2,…,Cm} が与えられるとします。任意のタプル ti={ti1,ti2,…,tik}∈D に対して、sim(ti,Cj)≥sim(ti,Cp) となる Cj∈C が存在する場合、Cp∈C、Cp≠Cj、が存在します。次に、ti がクラス Cj に割り当てられます。ここで、sim(ti,Cj) は類似度と呼ばれます。実際の計算では距離で表すことが多く、距離が近いほど類似度は大きくなり、距離が遠いほど類似度は小さくなります。

類似性を計算するには、まず各クラスを表すベクトルを取得する必要があります。計算方法は多数あり、例えば各クラスの中心を計算することで各クラスを表すベクトルを表すことができる。また、パターン認識では、各クラスを表すためにあらかじめ定義された画像が使用され、分類は、分類されるサンプルとあらかじめ定義された画像とを比較することです。

2. KNNアルゴリズムの基本的な考え方

KNN アルゴリズムの考え方は比較的単純です。各クラスに複数のトレーニング データが含まれており、各トレーニング データに一意のカテゴリ ラベルがあると仮定すると、KNN アルゴリズムの主な考え方は、各トレーニング データと分類されるタプルの間の距離を計算し、最も近い距離を取得することです。分類対象のタプル k 個のデータのうち k 個の訓練データ、k 個のデータの中でどのカテゴリーの訓練データが多いか、分類対象のタプルはどのカテゴリーに属するか。

3. KNNアルゴリズムと強相関ルールマイニングの例

KNN アルゴリズムの例
ここに画像の説明を挿入
ここに画像の説明を挿入

4. KNNアルゴリズムの実装プロセス

実験内容
身長と学年が登録されているクラスの生徒は 14 人で、新入生のイー・チャン君は身長 174cm、学年は何年生ですか。分類認識には knn アルゴリズム (k=5) を使用してください。
ここに画像の説明を挿入

実験アイデア
(1) 学生クラス Student を定義し、学生クラスの名前、身長、学年などの属性を定義し、lombok が依存する @Data アノテーションを使用して Student クラスの get メソッドと set メソッドを注入します。初期データ セットを定義し、14 個のエンティティ Student クラスを定義し、14 個のエンティティ Student クラスを初期データ セット dataList に追加します。
(2) initData() メソッドを呼び出してデータ セットを初期化し、Student クラス stuV0 を定義してその名前と高さを入力としてインスタンス化し、Knn() メソッドを呼び出して Student クラス オブジェクト Student と成績を取得し、オブジェクトを実行します。学生のアウトプット。
(3) Knn() メソッド本体内で、データ セットの最初の 5 項目が最初に categoryList コレクションに追加されます。categoryList コレクションは、stuV0 に最も近い k 人の生徒と、データの最初の 5 項目のみを保存するために使用されます。セットが最初に保存されます。データ セット dataList を走査し、stuV0 とデータ セットの項目 6 から始まる各項目の間の距離 v0Tod を計算し、getCalculate() メソッドを呼び出して、stuV0 と categoryList コレクション内の学生オブジェクト stuU の間の距離を計算します (stuU がstuV0 の高さ 距離 uToV0 が v0Tod より大きい場合、categoryList から stuU を削除し、データ セット内の項目を categoryList コレクションに追加します。
(4) getCalculate() メソッドの本体内で、stuV0 とカテゴリ セット categoryList の間の最も遠い距離を格納する変数 maxHeight を定義し、返される学生、つまり最も遠い距離を格納する Student クラス オブジェクト resultStu を定義します。 stuV0 とカテゴリ セット categoryList 学生の間の距離。categoryList コレクションを走査し、stuU と stuV0 の間の距離が maxHeight より大きい場合は、v0ToU を maxHeight に割り当て、stuU を resultStu に割り当て、最後に Student クラス オブジェクト resultStu を返します。
(5) getCategoryStudent() メソッドを呼び出して、categoryList 内で同じ学年の割合が最も大きい学生のランクを見つけ、最後に stuV0 のランク属性をランクでインスタンス化し、stuV0 を返します。
(6) getCategoryStudent() メソッドの本体で、categoryList をトラバースして、同じ学年の割合が最も大きい生徒の成績を見つけます。実際には、最高の成績、中間の成績、および最高の成績を持つ生徒を見つけることになります。 short Grade (どのカテゴリに最も多くの学生がいるか) を返し、最も多くの学生がいるカテゴリを返します。

ソースコードを実現する

学生类
package com.data.mining.entity;

import lombok.Data;

@Data
public class Student {
    
    
    private String name;
    private int height;
    private String rank;

    public Student(){
    
    }

    public Student(String n, int h){
    
    
        name = n;
        height = h;
    }

    public Student(String n, int h, String r){
    
    
        name = n;
        height = h;
        rank = r;
    }
}

KNN算法实现代码
package com.data.mining.main;

import com.data.mining.entity.Student;

import java.util.ArrayList;
import java.util.List;

public class Knn {
    
    
    //定义初始数据集
    public static List<Student> dataList = new ArrayList<>();

    public static void main(String[] args) {
    
    
        initData();
        Student stuV0 = new Student("易昌", 174);
        Student student = Knn(stuV0);
        System.out.println(student.toString());
    }

    /**
     * 找出同等级占比最多的学生等级
     * @param categoryList
     * @return
     */
    public static String getCategoryStudent(List<Student> categoryList){
    
    
        int tallCount = 0;
        int midCount = 0;
        int smallCount = 0;
        for (Student stuU : categoryList) {
    
    
            if (stuU.getRank().equals("高")) tallCount++;
            else if (stuU.getRank().equals("中等")) midCount++;
            else smallCount++;
        }
        int max = 0;
        max = tallCount > midCount ? tallCount : midCount;
        max = smallCount > max ? smallCount : max;
        if (smallCount == max) return "矮";
        else if (tallCount == max) return "高";
        else return "中等";
    }

    /**
     * 计算出stuV0距离categoryList集合中最远的学生对象
     * @param stuV0
     * @param categoryList
     * @return
     */
    public static Student getCalculate(Student stuV0, List<Student> categoryList) {
    
    
        int maxHeight = 0; //存放stuV0与类别集合categoryList的最远距离
        Student resultStu = new Student(); //存放要返回的学生,即stuV0与类别集合categoryList距离最远的学生
        for (Student stuU : categoryList) {
    
    
            int v0ToU = Math.abs(stuV0.getHeight() - stuU.getHeight()); //stuV0与stuU的距离
            if (v0ToU > maxHeight){
    
     //stuV0与stuU的距离大于maxHeight,则对maxHeight和resultStu进行更新
                maxHeight = v0ToU;
                resultStu = stuU;
            }
        }
        return resultStu;
    }

    /**
     * 对输入学生类进行Knn算法实例化该学生的等级后,将该学生返回
     * @param stuV0
     * @return
     */
    public static Student Knn(Student stuV0){
    
    
        List<Student> categoryList = new ArrayList<>(); //存放距离stuV0最近的k个学生,最初存放数据集的前5项
        for (int i = 0; i < dataList.size(); i++) {
    
    
            if (i < 5) categoryList.add(dataList.get(i));
            else {
    
    
                //stuV0距离剩下数据集中某项的距离
                int v0Tod = Math.abs(stuV0.getHeight() - dataList.get(i).getHeight());
                Student stuU =  getCalculate(stuV0, categoryList); //存放stuV0距离类别集合中最远的学生
                int uToV0 = Math.abs(stuU.getHeight() - stuV0.getHeight());
                if (uToV0 > v0Tod){
    
    
                    categoryList.remove(stuU); //在集合列表中去除stuU
                    categoryList.add(dataList.get(i));
                }
            }
        }
        System.out.println(categoryList.toString());
        String rank = getCategoryStudent(categoryList);
        stuV0.setRank(rank);

        return stuV0;
    }


    /**
     * 初始化数据
     */
    public static void initData(){
    
    
        Student s1 = new Student("李丽", 150, "矮");
        Student s2 = new Student("吉米", 192, "高");
        Student s3 = new Student("马大华", 170, "中等");
        Student s4 = new Student("王晓华", 173, "中等");
        Student s5 = new Student("刘敏", 160, "矮");
        Student s6 = new Student("张强", 175, "中等");
        Student s7 = new Student("李秦", 160, "矮");
        Student s8 = new Student("王壮", 190, "高");
        Student s9 = new Student("刘冰", 168, "中等");
        Student s10 = new Student("张喆", 178, "中等");
        Student s11 = new Student("杨毅", 170, "中等");
        Student s12 = new Student("徐田", 168, "中等");
        Student s13 = new Student("高杰", 165, "矮");
        Student s14 = new Student("张晓", 178, "中等");

        dataList.add(s1);
        dataList.add(s2);
        dataList.add(s3);
        dataList.add(s4);
        dataList.add(s5);
        dataList.add(s6);
        dataList.add(s7);
        dataList.add(s8);
        dataList.add(s9);
        dataList.add(s10);
        dataList.add(s11);
        dataList.add(s12);
        dataList.add(s13);
        dataList.add(s14);
    }
}

実験結果
ここに画像の説明を挿入
さらに、質問で必要な生徒の身長グレードを出力することに加えて、作成者は、結果が正しいことを確認するために質問を比較するために、入力された生徒が位置するクラスターも出力します。
ここに画像の説明を挿入

5. 実験の概要

著者は、この実験の結果が正しいことを保証するものではなく、Java 言語を使用して KNN アルゴリズムを実装するためのアイデアを提供しているだけです。実験では答えが得られなかったので、インターネット上にある答え付きの実験データをプログラムに入力し、プログラムが出力した結果は答えと一致しているので、問題は大きくないはずです。うまく書いていないところがあれば、ぜひご指摘ください!
著者のホームページには、他のデータ マイニング アルゴリズムの概要も掲載されています。ぜひご利用ください。

おすすめ

転載: blog.csdn.net/qq_54162207/article/details/128366181