K-means算法解析及代码

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/aaa1050070637/article/details/84537277

上周看到K-means算法,觉得挺有意思的,然后就分析了一下原理,又用JAVA实现了一下,水平有限,还请看到此博客的各路大神, 如果看到有误的地方,还请帮我纠正一下。

我给这个算法的定义:根据某种规则,将相同的或者相近的对象,存放到一起。

基本原理:

1.定义几个初始点当做基准点,

2.计算出当前的聚类,

3.根据新的聚类,确定下一个基准点,然后再次计算出新的聚类。

用到的数学基础有曼哈顿聚类距离,加权平均等。

下面贴一下代码

package com.dhc.jstestdemo.Model;

import android.util.Log;

import java.util.ArrayList;
import java.util.List;

/**
 * K-means算法解析
 * Created by 大漠dreamer on 2018/11/26.
 */

public class KMeans {

    /**
     * 需要一个集合来存放原始坐标
     */
    List<Point> original = null;
    Point point1 = null;
    Point point2 = null;
    Point point3 = null;
    Point point4 = null;

    Point point5 = null;
    Point point6 = null;
    Point point7 = null;
    Point point8 = null;

    /**
     * 3个新的聚类
     */
    List<Point> list1 = null;
    List<Point> list2 = null;
    List<Point> list3 = null;

    Point basePointOne = null;
    Point basePointTwo = null;
    Point basePointThree = null;

    /**
     * 聚类计算的次数
     */
    private static final int calculatorNumber = 2;
    private int calculator = 0;

    /**
     * 构造函数,可以根据自己的要求,来定制化需要进行聚类的点
     * 这里例子采用的是8个点, 3个初始化基准点的方法,最终计算十次之后来确定聚类
     */
    public KMeans() {
    }

    /**
     * 初始化
     * 这里例子采用的是8个点, 3个初始化基准点的方法,最终来确定聚类
     */
    public void initData() {
        original = new ArrayList<>();
        list1 = new ArrayList<>();
        list2 = new ArrayList<>();
        list3 = new ArrayList<>();


        point1 = new Point(1.0, 2.0);
        original.add(point1);
        point2 = new Point(4.0, 3.0);
        original.add(point2);
        point3 = new Point(3.0, 5.0);
        original.add(point3);
        point4 = new Point(4.0, 9.0);
        original.add(point4);

        point5 = new Point(2.0, 10.0);
        original.add(point5);
        point6 = new Point(6.0, 5.0);
        original.add(point6);
        point7 = new Point(5.0, 2.0);
        original.add(point7);
        point8 = new Point(7.0, 1.0);
        original.add(point8);

        //选取初始点,分别计算曼哈顿聚类距离,此处选取1,4,7为初始点
        basePointOne = point1;
        basePointTwo = point4;
        basePointThree = point7;
    }


    /**
     * 计算点到基准点的距离,并将数据添加到对应的集合中
     *
     * @param point
     */
    private void setPointToCluster(Point point, Point pointBase1
            , Point pointBase2, Point pointBase3) {

        Double distanceForOneToOne = ManHaDunDistance(point, pointBase1);
        Double distanceForOneToFour = ManHaDunDistance(point, pointBase2);
        Double distanceForOneToSeven = ManHaDunDistance(point, pointBase3);

        Double compareOne = Math.min(distanceForOneToOne, distanceForOneToFour);
        Double compareTwo = Math.min(compareOne, distanceForOneToSeven);
        if (compareTwo.equals(distanceForOneToOne)) {
            list1.add(point);
        } else if (compareTwo.equals(distanceForOneToFour)) {
            list2.add(point);
        } else {
            list3.add(point);
        }
    }

    /**
     * 计算下一个聚类,
     */
    public void getNextBasePointAndUpdateCluster() {


        calculator++;
        /**
         * 当递归次数已经到达限制次数之后,不再进行递归运算,计算停止
         */
        if (calculator == calculatorNumber) {
            return;
        }
        /**
         *  每次计算聚类的时候,清除上一次的聚类数据
         */
        if (list1 != null) {
            list1.clear();
        }
        if (list2 != null) {
            list2.clear();
        }
        if (list3 != null) {
            list3.clear();
        }

        setPointToCluster(point1, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point2, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point3, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point4, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point5, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point6, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point7, basePointOne, basePointTwo, basePointThree);
        setPointToCluster(point8, basePointOne, basePointTwo, basePointThree);

        basePointOne = new Point(getAverage(list1, true), getAverage(list1, false));
        basePointTwo = new Point(getAverage(list2, true), getAverage(list2, false));
        basePointThree = new Point(getAverage(list3, true), getAverage(list3, false));

        /**
         * 递归继续算下一个点和聚类
         */
        getNextBasePointAndUpdateCluster();

    }

    /**
     * 计算数字的加权平均值
     */
    private Double getAverage(List<Point> list, boolean isX) {

        Double sum = 0.0;

        for (int i = 0; i < list.size(); i++) {
            if (isX) {
                sum = sum + list.get(i).getX();
            } else {
                sum = sum + list.get(i).getY();
            }
        }

        return sum / list.size();
    }

    /**
     * 曼哈顿聚类距离
     */
    private Double ManHaDunDistance(Point pointOne, Point pointTwo) {

        return Math.abs(pointTwo.getX() - pointOne.getX()) +
                Math.abs(pointTwo.getY() - pointOne.getY());
    }

    
    public void typeList() {
        for (int i = 0; i < list1.size(); i++) {
            Point point = list1.get(i);
            int index = original.indexOf(point);
            Log.d("cluster", "我来自聚类1----横坐标为:" + point.getX()
                    + "纵坐标为:" + point.getY() +
                    "位于原始集合里面的:" + (index + 1) + "位置");
        }
        for (int i = 0; i < list2.size(); i++) {
            Point point = list2.get(i);
            int index = original.indexOf(point);
            Log.d("cluster", "我来自聚类2----横坐标为:" + point.getX()
                    + "纵坐标为:" + point.getY() +
                    "位于原始集合里面的:" + (index + 1) + "位置");
        }
        for (int i = 0; i < list3.size(); i++) {
            Point point = list3.get(i);
            int index = original.indexOf(point);
            Log.d("cluster", "我来自聚类3----横坐标为:" + point.getX()
                    + "纵坐标为:" + point.getY() +
                    "位于原始集合里面的:" + (index + 1) + "位置");
        }
    }

    /**
     * 坐标类
     */
    class Point {

        Double x;
        Double y;

        public Point(Double x, Double y) {
            this.x = x;
            this.y = y;
        }

        public Double getX() {
            return x;
        }

        public void setX(Double x) {
            this.x = x;
        }

        public Double getY() {
            return y;
        }

        public void setY(Double y) {
            this.y = y;
        }
    }
}

测试方法

 private void clusterCal() {
        KMeans kMeans = new KMeans();
        kMeans.initData();
        kMeans.getNextBasePointAndUpdateCluster();
        kMeans.typeList();
    }

运行结果  这里需要纠正一个错误,就是聚类运算的次数,不是自己定义的,而是计算到一定程度,聚类基准点会不再发生变化,此时,代表计算完成。需要手动判断,基准点是否还在发生变化。

猜你喜欢

转载自blog.csdn.net/aaa1050070637/article/details/84537277