Java实例_综合实践3.K-Means聚类算法

编写程序实现K-Means聚类算法 (见图 4-12)。 K-Means是一种聚类算法,后者属于机器学习中的无监督学习,用于识别给定数据集 中的若干数据簇——即每个数据所属的分类。

在这里插入图片描述

K-Means算法的主要思想如下:
(1) 任选数据集中的K个数据点作为 K个分类的初始质心 (几何中心),其中K为数据 的总分类数,由人通过观察整个数据集的分布而事先确定。
(2) 将每个数据点指派到离其最近的质心所属的分类。
(3) 重新计算每个分类的质心。
(4) 重复 (2)、(3),直至达到最大迭代次数或所有分类的质心不再变化

import java.awt.Color;
import java.awt.Graphics;
import java.util.Random;

import javax.swing.JFrame;

public class KMeans {
    int N = 300; // 数据点个数
    int K = 3;// 分类数

    double[][] points = new double[N][2];// 数据点(2列分别表示x、y坐标)
    double[][] centers = new double[K][2];// 分类质心
    double[][] distances = new double[N][K];// 数据点到分类质心的距离
    int[] kinds = new int[N];// 数据点所属的分类

    JFrame win; // 用于呈现分类结果的窗口
    final int WIN_HEIGHT = 420;// 窗口高度
    final int WIN_WIDTH = 600;// 窗口宽度

    // 分类标记和颜色
    final String[] MARKER_TEXTS = { "X", "O", "+" };
    final Color[] MARKER_COLORS = { Color.RED, Color.MAGENTA, Color.BLUE };

    // 围绕给定的K个点随机产生N个数据点
    void generatePoints() {
        // 给定K个点的坐标
        double[][] ps = { { 300, 50 }, { 480, 300 }, { 120, 200 } }; 
        Random r = new Random();// 随机数对象
        int p; 
        for (int i = 0; i < N; i++) {
            p = r.nextInt(K);// 随机选择一个给定的点
            // 围绕点p随机产生一个数据点
            points[i][0] = ps[p][0] + (r.nextBoolean() ? r.nextInt(100) : -r.nextInt(100));
            points[i][1] = ps[p][1] + (r.nextBoolean() ? r.nextInt(50) : -r.nextInt(50));
        }
    }

    // 选择K个数据点作为K个分类的初始质心
    void initCenters() {
        for (int i = 0; i < K; i++) {
            centers[i][0] = points[i * N / K][0];
            centers[i][1] = points[i * N / K][1];
        }
    }

    // 对每个数据点,按离其最近的质心进行分类
    void assignPoints() {
        double dx, dy;
        double min;
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < K; j++) {// 计算点i到K个质心的距离
                dx = points[i][0] - centers[j][0];
                dy = points[i][1] - centers[j][1];
                distances[i][j] = dx * dx + dy * dy;
            }
            // 选择离其最近的质心作为点i所属的分类
            min = distances[i][0];
            kinds[i] = 0;
            for (int j = 1; j < K; j++) {
                if (distances[i][j] < min) {
                    min = distances[i][j];
                    kinds[i] = j;
                }
            }
        }
    }

    // 根据所有点的分类,计算新的质心
    void calcCenters() {
        for (int j = 0; j < K; j++) {
            centers[j][0] = 0;
            centers[j][1] = 0;
            int count = 0;
            // 统计属于分类j的点的个数
            for (int i = 0; i < N; i++) {
                if (kinds[i] == j) {
                    centers[j][0] += points[i][0];
                    centers[j][1] += points[i][1];
                    count++;
                }
            }
            // 计算分类j的质心坐标(x、y的算术均值——K-Means算法名称的由来)
            centers[j][0] /= count;
            centers[j][1] /= count;
        }
    }

    // 初始化UI(仅用于呈现结果)
    void initUI() {
        win = new JFrame("K-Means");
        win.setSize(WIN_WIDTH, WIN_HEIGHT);
        win.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        win.setResizable(false);
        win.setVisible(true);
    }

    // 绘制N个数据点及K个质心(仅用于呈现结果)
    void plot() {
        Graphics g = win.getGraphics();
        g.clearRect(0, 0, WIN_WIDTH, WIN_HEIGHT);// 清除之前绘制的点
        // 以不同标记和颜色绘制各分类中的点
        for (int i = 0; i < N; i++) {
            g.setColor(MARKER_COLORS[kinds[i]]);
            g.drawString(MARKER_TEXTS[kinds[i]], (int) points[i][0], (int) (WIN_HEIGHT - 20 - points[i][1]));
        }
        // 绘制各分类的质心
        for (int i = 0; i < K; i++) {
            g.setColor(Color.BLACK);
            g.drawString("★", (int) centers[i][0], (int) (WIN_HEIGHT - 20 - centers[i][1]));
        }
    }

    // 程序入口
    public static void main(String[] args) throws InterruptedException {
        KMeans kMeans = new KMeans();
        kMeans.generatePoints();
        kMeans.initCenters();
        kMeans.initUI();

        for (int i = 0; i < 10; i++) { // 迭代10次
            kMeans.plot();
            kMeans.assignPoints();
            kMeans.calcCenters();
            Thread.sleep(500); // 暂停0.5秒(便于观察迭代过程)
        }
    }
}

猜你喜欢

转载自blog.csdn.net/qq_46672746/article/details/107587197