データ マイニング Java - Kmeans アルゴリズムの実装

1. K 平均法アルゴリズムの予備知識

k-means アルゴリズム は、k-means または k-means とも呼ばれ、最も広く使用されているクラスタリング アルゴリズムの 1 つです。類似性の計算は、クラスター内のオブジェクトの平均値に基づいて行われます。アルゴリズムは最初に k 個のオブジェクトをランダムに選択し、各オブジェクトは最初にクラスターの平均または中心を表します。残りの各オブジェクトについて、各クラスターの中心からの距離に従って、最も近いクラスターに割り当てます。次に、各クラスターの平均を再計算します。このプロセスは、基準関数が収束するまで繰り返されます。

クラスタリングとは、データ オブジェクトを複数のクラスまたはクラスターにグループ化することです。分割の原理は、同じクラスター内のオブジェクトは高度な類似性を持ち、異なるクラスター内のオブジェクトはまったく異なるということです。分類とは異なり、クラスタリング操作で分割されるクラスは事前に不明であり、クラスの形式は完全にデータ駆動型であり、教師なし学習手法に属します。

クラスター分析は、データマイニング、統計、機械学習、パターン認識などを含む多くの研究分野に由来します。これはデータ マイニングの機能ですが、データ分布を取得したり、各クラスターの特性を要約したり、特定のクラスターに焦点を当ててさらに分析したりするための独立したツールとしても使用できます。さらに、クラスター分析は、結果として得られるクラスターに対して動作する他の分析アルゴリズム (相関ルール、分類など) の前処理ステップとしても使用できます。

クラスタリング: クラスタリングは、いくつかの点で類似しているデータ メンバーを分類および整理するプロセスです。クラスタリングは、この内部構造を発見するための手法です。クラスタリング手法は、多くの場合、教師なし学習と呼ばれます。
K-means クラスタリング: K-means クラスタリングは、最も有名なパーティショニング クラスタリング アルゴリズムであり、そのシンプルさと効率性により、すべてのクラスタリング アルゴリズムの中で最も広く使用されています。データ ポイントのセットと必要なクラスター数 k、k がユーザーによって指定されると、k 平均法アルゴリズムは、特定の距離関数に従ってデータを k クラスターに繰り返し分割します。

2. K-meansアルゴリズムの基本的な考え方

K 平均法クラスタリング アルゴリズムは、最初に K 個のオブジェクトを初期クラスター中心としてランダムに選択します。次に、各オブジェクトと各シード クラスターの中心の間の距離を計算し、各オブジェクトをそれに最も近いクラスターの中心に割り当てます。クラスターの中心とそれに割り当てられたオブジェクトはクラスターを表します。すべてのオブジェクトが割り当てられると、各クラスターのクラスター中心がクラスター内の既存のオブジェクトに基づいて再計算されます。このプロセスは、特定の終了条件が満たされるまで繰り返されます。終了条件は、別のクラスターに再割り当てされるオブジェクトがない (または最小数)、クラスター中心が再び変更されない (または最小数)、誤差の二乗合計が局所的に最小であることです。

3. K-means アルゴリズムの例

K 平均法アルゴリズムの例
ここに画像の説明を挿入
ここに画像の説明を挿入

4. K-meansアルゴリズムの実装プロセス

実験内容
以下の表のデータに対して k-mean クラスタリングを実行してください。距離はユークリッド距離、k=3 です。
ここに画像の説明を挿入
実験アイデア
(1) 静的メソッド getIsSame を含む、横座標 x および縦座標 y などの属性を含む Point クラスを定義します。 (): 2 つの Point クラス オブジェクトが同じかどうかを判断します。calculateDistance() メソッド: 2 つの Point クラス オブジェクト間の距離 (ユークリッド距離) を計算します。calculateMHDDistance() メソッド: 2 つの Point クラス オブジェクト間の距離 (マンハッタン距離) を計算します。属性コア ポイント corePoint とクラスター SameList 内のすべてのポイントのコレクションを含む Cluster クラスを定義します。
(2) 初期データセット dataList を定義し、クラスターの数 k を定義し、initDataList() メソッドを呼び出してデータセットを初期化し、getInitCluster() メソッドを呼び出してクラスターを初期化します。getInitCluster() メソッドの主な機能は、初期クラスター中心として任意の k オブジェクトを取得し、k クラスターを含むセットを返すことです。getInitCluster() メソッドの本体内で、k 個のクラスターを格納するためのクラスターリスト コレクションを定義し、getRandomArray() メソッドを呼び出して、k 個の非繰り返し乱数を含む配列randomArray を取得し、データ セット内の k オブジェクトの添字を格納します。配列randomArray、traverse 配列randomArrayは、対応するクラスタのコアオブジェクトポイントとして任意の添え字を持つk Pointオブジェクトを取り出し、各サイクルの定義とインスタンス化の後にクラスタをclusterListに追加し、最後にclusterListコレクションを返します。
(3) while ループに入り、データセット dataList 内の各ポイントを走査し、getBelongCluster() メソッドを呼び出して、そのポイントが属するクラスターのクラスターリスト内の添字インデックスを取得し、指定された添字インデックスを持つクラスターを取り出します。クラスタリストを作成し、ポイントポイントをクラスタのsameListに追加します。次に、データ セットを走査した後、calculateClusterCore() メソッドを呼び出して新しいクラスターの中心を計算し、クラスター セット内の各クラスターの点セットが変更されたかどうかを確認します。変更がない場合は、while ループから抜け出し、次のことを示します。 K 平均法クラスタリング クラスは終了します。それ以外の場合は、次の while ループに入ります。データ セットを走査する前に、clusterList コレクション内の各クラスターの SameList コレクションをクリアする必要があります。
(4) clusetrList コレクションを走査し、コレクション内の各項目クラスターを出力します。
(5) getBelongCluster() メソッドの主な機能は、特定の点がどのクラスターに属する添え字を取得することです。メソッド本体内で、変数 ClosestDistance と変数 resultClusterIndex が定義され、それぞれポイントとクラスター中心間の最も近い距離と、ポイントが属するクラスターの添え字が格納されます。クラスター コレクション clusterList を走査し、Point クラスの静的メソッド CalculateDistance() を呼び出して、ポイントとクラスター クラスターのコア ポイントの間の距離を計算して距離に割り当て、最初の走査で取得した距離値を距離に割り当てます。後続の走査で距離が最も近いDistanceより小さい場合、距離を最も近いDistanceに割り当て、同時にインデックスをresultClusterIndexに割り当て、ループの走査は終了し、最後にresultClusterIndexを返します。
(6) CalculateClusterCore() メソッドの主な機能は、新しいクラスターの中心を計算し、クラスターの点セットが変更されたかどうかを返すことです。メソッド本体内でフラグ変数 flag を定義し、clusterList コレクション内の各項目クラスターを走査し、変数 sumX と変数 sumY を定義して、クラスターの中心点セットのすべての x 座標の合計と、クラスターの中心点セットのすべての y 座標を格納します。クラスター中心点セット Sum、sumX と sumY の平均値を計算し、それを新しいクラスター中心点に割り当てます。clusterCore を割り当てます。Point クラスの静的メソッド getIsSame() を呼び出して、clusterCore が元のクラスター中心と同じかどうかを判断します。そうでない場合は、フラグを true に設定します。クラスター セットを走査した後、フラグ値が返されます。ここでの仮パラメータの型は List コレクションであり、List コレクションのアドレスが渡されることに注意してください。メソッド本体でコレクションを変更すると、実パラメータの値が変更されます。

ソースコードを実現する

Cluster类
package com.data.mining.entity;

import lombok.Data;

import java.util.ArrayList;
import java.util.List;

@Data
public class Cluster {
    
    
    private Point corePoint;
    private List<Point> sameList = new ArrayList<>();

    public Cluster(){
    
    }

    public Cluster(Point cp){
    
    
        corePoint = cp;
    }
}

Point类
package com.data.mining.entity;

import lombok.Data;

@Data
public class Point {
    
    
    private double x;
    private double y;

    public Point(){
    
    }

    public Point(double x, double y){
    
    
        this.x = x;
        this.y = y;
    }

    public static boolean getIsSame(Point p1, Point p2){
    
    
        if (p1.getX() == p2.getX() && p1.getY() == p2.getY()) return true;
        return false;
    }

    public static double calculateDistance(Point p1, Point p2){
    
    
        double xDistance = p1.getX() - p2.getX();
        double yDistance = p1.getY() - p2.getY();
        double tmp = xDistance * xDistance + yDistance * yDistance;
        return Math.sqrt(tmp);
    }

    public static double calculateMHDDistance(Point p1, Point p2){
    
    
        return Math.abs(p1.getX() - p2.getX()) + Math.abs(p1.getY() - p2.getY());
    }

}

K-means算法实现代码
package com.data.mining.main;

import com.data.mining.entity.Cluster;
import com.data.mining.entity.Point;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

public class Kmeans {
    
    
    //定义初始数据集
    public static List<Point> dataList = new ArrayList<>();
    //定义簇的数目
    public static Integer k = 3;

    public static void main(String[] args) {
    
    
        //初始化数据集和初始簇
        initDataList();
        List<Cluster> clusterList = getInitCluster();
        while(true){
    
    
            for (int j = 0; j < k; j++) {
    
    
                clusterList.get(j).getSameList().clear();
            }
            for (Point point : dataList) {
    
    
                int index = getBelongCluster(point, clusterList); //获取point属于的那个簇在clusterList中的下标
                clusterList.get(index).getSameList().add(point); //把point加入到clusterList的对应簇中;
            }
            if (!calculateClusterCore(clusterList)) break;
        }
        for (Cluster cluster : clusterList) {
    
    
            System.out.println(cluster);
        }
    }

    /**
     * 计算出新的簇中心并返回簇的点集合是否有变化
     * @param clusterList
     * @return
     */
    public static boolean calculateClusterCore(List<Cluster> clusterList){
    
    
        boolean flag = false;
        //遍历簇集合中的每一项,更新其簇中心
        for (Cluster cluster : clusterList) {
    
    
            List<Point> sameList = cluster.getSameList();
            double sumX = 0; //存放簇中点集合所有的X坐标之和
            double sumY = 0; //存放簇中点集合所有的Y坐标之和
            for (Point point : sameList) {
    
    
                sumX += point.getX();
                sumY += point.getY();
            }
            //更新簇的中心
            Point clusterCore = new Point(sumX * 1.0 / sameList.size(), sumY * 1.0 / sameList.size());
            if (!Point.getIsSame(clusterCore, cluster.getCorePoint())) flag = true;
            cluster.setCorePoint(clusterCore);
        }
        return flag;
    }

    /**
     * 获取某个点属于哪个簇的下标
     * @param point
     * @return
     */
    public static int getBelongCluster(Point point, List<Cluster> clusterList){
    
    
        double closestDistance = 0.0; //存放point距离簇中心最近的距离
        int resultClusterIndex = 0; //存放point属于的那个簇的下标
        int index = 0;
        //遍历簇集合,计算point到簇中心的距离,找出point属于的簇
        for (Cluster cluster : clusterList) {
    
    
            double distance = Point.calculateDistance(point, cluster.getCorePoint());
            if (index == 0) closestDistance = distance;
            if (distance < closestDistance){
    
    
                closestDistance = distance;
                resultClusterIndex = index;
            }
            index++;
        }
        return resultClusterIndex;
    }

    /**
     * 获取任意k个对象作为初始簇中心,将含有k个簇的集合返回
     * @return
     */
    public static List<Cluster> getInitCluster(){
    
    
        List<Cluster> clusterList = new ArrayList<>();
        int[] randomArray = getRandomArray();
        //任意选取k个对象作为初始簇中心,数据集中k个对象的下标存放在randomArray中
        for (int i = 0; i < randomArray.length; i++) {
    
    
            Point point = dataList.get(randomArray[i]);
            Cluster cluster = new Cluster(point);
            clusterList.add(cluster);
        }
        return clusterList;
    }

    /**
     * 获取含有k个不重复随机数的数组
     * @return
     */
    public static int[] getRandomArray(){
    
    
        Random random = new Random();
        int[] randomArray = new int[k];
        for (int i = 0; i < k; i++) {
    
    
            int randomItem = random.nextInt(12);
            //为保证randomArray中存放的随机数不重复
            while (Arrays.binarySearch(randomArray, randomItem) >= 0) randomItem = random.nextInt(12);
            randomArray[i] = randomItem;
        }
        return randomArray;
    }

    /**
     * 初始化数据集
     */
    public static void initDataList(){
    
    
        Point p1 = new Point(1, 2);
        Point p2 = new Point(2, 1);
        Point p3 = new Point(2, 4);
        Point p4 = new Point(4, 3);
        Point p5 = new Point(5, 8);
        Point p6 = new Point(6, 7);
        Point p7 = new Point(6, 9);
        Point p8 = new Point(7, 9);
        Point p9 = new Point(9, 5);
        Point p10 = new Point(1, 12);
        Point p11 = new Point(3, 12);
        Point p12 = new Point(5, 12);
        Point p13 = new Point(3, 3);

        dataList.add(p1);
        dataList.add(p2);
        dataList.add(p3);
        dataList.add(p4);
        dataList.add(p5);
        dataList.add(p6);
        dataList.add(p7);
        dataList.add(p8);
        dataList.add(p9);
        dataList.add(p10);
        dataList.add(p11);
        dataList.add(p12);
        dataList.add(p13);
    }
}


実験結果
ここに画像の説明を挿入
k の値が 3 であるため、出力結果は間違いなく 3 クラスターになります。ここで、著者はいくつかのテストを実行しましたが、テストの数が増えるにつれてテスト結果が異なることがわかりました。データを収集した後、著者は個人的にこの状況が正常であると考えています。その理由は、クラスター センターがクラスター センターでランダムに選択されるためです。最初に選択したクラスターの中心点がコンパクトすぎたり、避難しすぎたりする可能性があり、最終的な出力結果に影響を与える可能性があります。多くのテストの結果、作成者は、次の頻度が最も高い出力結果のグループがあることを発見しました。この出力結果のグループを図に示します。この写真の文字がなぜこんなに小さいのかわかりません。とにかく、はっきりと読めないので、表を使って記入してください。
ここに画像の説明を挿入

5. 実験の概要

著者は、この実験の結果が正しいことを保証するものではなく、Java 言語を使用して K 平均法アルゴリズムを実装する方法を提供しているだけです。実験では答えが得られなかったので、著者は本に載っている答え付きの実験データをプログラムに入力し、プログラムが出力する結果は答えと一致しているので、問題は大きくないはずです。うまく書いていないところがあれば、ぜひご指摘ください!
著者のホームページには、他のデータ マイニング アルゴリズムの概要も掲載されています。ぜひご利用ください。

おすすめ

転載: blog.csdn.net/qq_54162207/article/details/128366655