版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/aaa1050070637/article/details/84537277
上周看到 K-means 算法,觉得挺有意思,于是分析了一下原理,又用 Java 实现了一遍。本人水平有限,如果各路大神发现文中有误的地方,还请帮我纠正。
我给这个算法的定义:根据某种规则,将相同的或者相近的对象,存放到一起。
基本原理:
1.选定几个初始点当做基准点,
2.把每个点归入距离最近的基准点,计算出当前的聚类,
3.根据新的聚类,确定下一个基准点,然后再次计算出新的聚类,如此反复,直到基准点不再变化。
用到的数学基础有曼哈顿距离,算术平均值等。
下面贴一下代码
package com.dhc.jstestdemo.Model;
import android.util.Log;
import java.util.ArrayList;
import java.util.List;
/**
* K-means算法解析
* Created by 大漠dreamer on 2018/11/26.
*/
public class KMeans {
/**
* 需要一个集合来存放原始坐标
*/
List<Point> original = null;
Point point1 = null;
Point point2 = null;
Point point3 = null;
Point point4 = null;
Point point5 = null;
Point point6 = null;
Point point7 = null;
Point point8 = null;
/**
* 3个新的聚类
*/
List<Point> list1 = null;
List<Point> list2 = null;
List<Point> list3 = null;
Point basePointOne = null;
Point basePointTwo = null;
Point basePointThree = null;
/**
* 聚类计算的次数
*/
private static final int calculatorNumber = 2;
private int calculator = 0;
/**
* 构造函数,可以根据自己的要求,来定制化需要进行聚类的点
* 这里例子采用的是8个点, 3个初始化基准点的方法,最终计算十次之后来确定聚类
*/
public KMeans() {
}
/**
* 初始化
* 这里例子采用的是8个点, 3个初始化基准点的方法,最终来确定聚类
*/
public void initData() {
original = new ArrayList<>();
list1 = new ArrayList<>();
list2 = new ArrayList<>();
list3 = new ArrayList<>();
point1 = new Point(1.0, 2.0);
original.add(point1);
point2 = new Point(4.0, 3.0);
original.add(point2);
point3 = new Point(3.0, 5.0);
original.add(point3);
point4 = new Point(4.0, 9.0);
original.add(point4);
point5 = new Point(2.0, 10.0);
original.add(point5);
point6 = new Point(6.0, 5.0);
original.add(point6);
point7 = new Point(5.0, 2.0);
original.add(point7);
point8 = new Point(7.0, 1.0);
original.add(point8);
//选取初始点,分别计算曼哈顿聚类距离,此处选取1,4,7为初始点
basePointOne = point1;
basePointTwo = point4;
basePointThree = point7;
}
/**
* 计算点到基准点的距离,并将数据添加到对应的集合中
*
* @param point
*/
private void setPointToCluster(Point point, Point pointBase1
, Point pointBase2, Point pointBase3) {
Double distanceForOneToOne = ManHaDunDistance(point, pointBase1);
Double distanceForOneToFour = ManHaDunDistance(point, pointBase2);
Double distanceForOneToSeven = ManHaDunDistance(point, pointBase3);
Double compareOne = Math.min(distanceForOneToOne, distanceForOneToFour);
Double compareTwo = Math.min(compareOne, distanceForOneToSeven);
if (compareTwo.equals(distanceForOneToOne)) {
list1.add(point);
} else if (compareTwo.equals(distanceForOneToFour)) {
list2.add(point);
} else {
list3.add(point);
}
}
/**
* 计算下一个聚类,
*/
public void getNextBasePointAndUpdateCluster() {
calculator++;
/**
* 当递归次数已经到达限制次数之后,不再进行递归运算,计算停止
*/
if (calculator == calculatorNumber) {
return;
}
/**
* 每次计算聚类的时候,清除上一次的聚类数据
*/
if (list1 != null) {
list1.clear();
}
if (list2 != null) {
list2.clear();
}
if (list3 != null) {
list3.clear();
}
setPointToCluster(point1, basePointOne, basePointTwo, basePointThree);
setPointToCluster(point2, basePointOne, basePointTwo, basePointThree);
setPointToCluster(point3, basePointOne, basePointTwo, basePointThree);
setPointToCluster(point4, basePointOne, basePointTwo, basePointThree);
setPointToCluster(point5, basePointOne, basePointTwo, basePointThree);
setPointToCluster(point6, basePointOne, basePointTwo, basePointThree);
setPointToCluster(point7, basePointOne, basePointTwo, basePointThree);
setPointToCluster(point8, basePointOne, basePointTwo, basePointThree);
basePointOne = new Point(getAverage(list1, true), getAverage(list1, false));
basePointTwo = new Point(getAverage(list2, true), getAverage(list2, false));
basePointThree = new Point(getAverage(list3, true), getAverage(list3, false));
/**
* 递归继续算下一个点和聚类
*/
getNextBasePointAndUpdateCluster();
}
/**
* 计算数字的加权平均值
*/
private Double getAverage(List<Point> list, boolean isX) {
Double sum = 0.0;
for (int i = 0; i < list.size(); i++) {
if (isX) {
sum = sum + list.get(i).getX();
} else {
sum = sum + list.get(i).getY();
}
}
return sum / list.size();
}
/**
* 曼哈顿聚类距离
*/
private Double ManHaDunDistance(Point pointOne, Point pointTwo) {
return Math.abs(pointTwo.getX() - pointOne.getX()) +
Math.abs(pointTwo.getY() - pointOne.getY());
}
public void typeList() {
for (int i = 0; i < list1.size(); i++) {
Point point = list1.get(i);
int index = original.indexOf(point);
Log.d("cluster", "我来自聚类1----横坐标为:" + point.getX()
+ "纵坐标为:" + point.getY() +
"位于原始集合里面的:" + (index + 1) + "位置");
}
for (int i = 0; i < list2.size(); i++) {
Point point = list2.get(i);
int index = original.indexOf(point);
Log.d("cluster", "我来自聚类2----横坐标为:" + point.getX()
+ "纵坐标为:" + point.getY() +
"位于原始集合里面的:" + (index + 1) + "位置");
}
for (int i = 0; i < list3.size(); i++) {
Point point = list3.get(i);
int index = original.indexOf(point);
Log.d("cluster", "我来自聚类3----横坐标为:" + point.getX()
+ "纵坐标为:" + point.getY() +
"位于原始集合里面的:" + (index + 1) + "位置");
}
}
/**
* 坐标类
*/
class Point {
Double x;
Double y;
public Point(Double x, Double y) {
this.x = x;
this.y = y;
}
public Double getX() {
return x;
}
public void setX(Double x) {
this.x = x;
}
public Double getY() {
return y;
}
public void setY(Double y) {
this.y = y;
}
}
}
测试方法
/**
 * Runs the K-means demo end to end: loads the sample points, performs the clustering,
 * then logs which cluster each point was placed in.
 */
private void clusterCal() {
    final KMeans clusterer = new KMeans();

    // Order matters: the data must be loaded before clustering, and clustering
    // must run before the result is logged.
    clusterer.initData();
    clusterer.getNextBasePointAndUpdateCluster();
    clusterer.typeList();
}
运行结果
这里需要纠正一个错误:聚类运算的次数不应该由自己定义。计算到一定程度之后,聚类基准点不会再发生变化,此时代表计算完成,因此需要在代码中判断基准点是否还在发生变化。