k-Means算法,聚类算法

最近看随机游走算法时,遇到了聚类算法。 结合当时参考的论文和一些博客。 整理了下思维,写了下面的算法。


package kMeans;

import java.util.Arrays;


/**   
 * @ClassName:  KMeansM   
 * @Description:k-Means算法,聚类算法
  实现步骤:   1. 首先是随机获取总体中的K个元素作为总体的K个中心;  
  2. 接下来对总体中的元素进行分类,每个元素都去判断自己到K个中心的距离,并归类到最近距离中心去;  
  3. 计算每个聚类的平均值,并作为新的中心点  
  4. 重复2,3步骤,直到这k个中线点不再变化(收敛了),或执行了足够多的迭代  
 * @author: muliming 
 * @date:   2017年11月21日 下午11:12:11     
 * @Copyright: 2017 
 */
public class KMeansM {
//定义模拟数据源  一个三维的点   选择18个点
private static double[][] DATA = { 
{5.1,3.5,1.4},{4.9,3.0,1.4},{4.7,3.2,1.3},{4.6,3.1,1.5},{5.0,3.6,1.4},{5.4,3.9,1.7},
{4.6,3.4,1.4},{5.0,3.4,1.5},{4.4,2.9,1.4},{4.9,3.1,1.5},{5.4,3.7,1.5},{4.8,3.4,1.6},
{4.8,3.0,1.4},{4.3,3.0,1.1},{5.8,4.0,1.2},{5.7,4.4,1.5},{5.4,3.9,1.3},{5.1,3.5,1.4},
{5.7,3.8,1.7},{5.1,3.8,1.5},{5.4,3.4,1.7},{5.1,3.7,1.5},{4.6,3.6,1.0},{5.1,3.3,1.7},
{4.8,3.4,1.9},{5.0,3.0,1.6},{5.0,3.4,1.6},{5.2,3.5,1.5},{5.2,3.4,1.4},{4.7,3.2,1.6},
{4.8,3.1,1.6},{5.4,3.4,1.5},{5.2,4.1,1.5},{5.5,4.2,1.4},{4.9,3.1,1.5},{5.0,3.2,1.2},
{5.5,3.5,1.3},{4.9,3.1,1.5},{4.4,3.0,1.3},{5.1,3.4,1.5},{5.0,3.5,1.3},{4.5,2.3,1.3},
{4.4,3.2,1.3},{5.0,3.5,1.6},{5.1,3.8,1.9},{4.8,3.0,1.4},{5.1,3.8,1.6},{4.6,3.2,1.4},
{5.3,3.7,1.5},{5.0,3.3,1.4}};

public int k;//选择k个中心点进行聚合 
public int[][] tempClusters;//记录每个中心点下点的索引号
public int[] elementsInCenters; //记录每个中心点所属的各自的类的个数
public double[][] centers; //中心点
public int[] memberShip; //记录每个点的中心点的索引号  为了后面的比较方便

    //有参构成,初始化中心点的个数
    public KMeansM(int k){  
        this.k = k; //定义中心点的个数
    }


//--主函数--------------------------------------------------------
    public static void main(String[] args) {  
    KMeansM kmeansM = new KMeansM(3); //初始化点数
        String lastMembership = "";  
        String nowMembership = "";  
        int i=0;  
        kmeansM.firstCenters();//初试化中心点
        System.out.println("第一次选取得中心点为:");  
        for(int n=0;n<kmeansM.centers.length;n++){
        System.out.print(Arrays.toString(kmeansM.centers[n])+",");
        }
        System.out.println();
        boolean isEnd=true;
        while(isEnd){  
            i++;  
            kmeansM.calMemberShip(); //归类  寻找每个点所属的中心点,并记录每个中心点自己最终含有的点的总数 
            nowMembership = Arrays.toString(kmeansM.memberShip);//把当前的赋值,记录
            if(nowMembership.equals(lastMembership)){  
            System.out.println("");
                System.out.println("一共聚类了 "+(i-1)+" 次!");  
                for(int n=0;n<kmeansM.centers.length;n++){
                System.out.print(Arrays.toString(kmeansM.centers[n])+",");
                }
                isEnd=false;  
            }else{  
            kmeansM.calNewCenters();
                lastMembership = nowMembership;
                System.out.println("第 "+i+" 次聚合");  
                for(int n=0;n<kmeansM.centers.length;n++){
                System.out.print(Arrays.toString(kmeansM.centers[n])+",");
                }
                System.out.println();
                System.out.println();
            }  
        }  
    }
    
    //-------工具方法---------------------------------------------------------
    
    /**
     * @Title: firstCenters   
     * @Description: 初试化中心点,选取前k个点为初试中心点   
     * @return: 创建初始化中心点      
     * @throws
     */
    public void firstCenters(){
    centers = new double[k][DATA[0].length];  
        for(int i=0;i<k;i++){  
            for(int j=0;j<DATA[i].length;j++){
            centers[i][j]=DATA[i][j];
            }
        }  
    }
    
    /**
     * @Title: manhattanDistince   
     * @Description: 计算临近距离   每一串数据与选出的中心点的距离,   
     * @param: @param paraFirstData 第一个点
     * @param: @param paraSecondData  第二个点
     * @return: double  返回两个点之间的距离 
     * @throws
     */
    public double manhattanDistince(double[] paraFirstData,double[] paraSecondData){  
        double tempDistince = 0;  
        if((paraFirstData!=null && paraSecondData!=null) && paraFirstData.length==paraSecondData.length){  
            for(int i=0;i<paraFirstData.length;i++){  
                tempDistince += Math.abs(paraFirstData[i] - paraSecondData[i]);  
            }  
        }else{  
            System.out.println("firstData 与 secondData 数据结构不一致");  
        }  
        return tempDistince;  
    } 
    
    /**
     * @Title: calNewCenters   
     * @Description: 生成新的中心点  ,每次中心点取该类中的平均值
     */
    public void calNewCenters(){
    double[][] tempCenters=new double[k][3]; //中心点
      //求和
        for(int i=0;i<k;i++){  
            for(int j=0;j<elementsInCenters[i];j++){  
               for(int k=0;k<DATA[i].length;k++){
              tempCenters[i][k]+=DATA[tempClusters[i][j]][k];
               }
            }  
        }  
      //取平均值
        for(int i=0;i<centers.length;i++){  
            for(int j=0;j<DATA[0].length;j++){  
            if(elementsInCenters[i]!=0){
            tempCenters[i][j] /= elementsInCenters[i];  
            }else{
            tempCenters[i][j] = centers[i][j];  
}
            }  
        }
        centers=tempCenters;
    } 
    
 
    /**
     * @Title: calMemberShip   
     * @Description: 寻找每个点所属的中心点,并记录每个中心点自己最终含有的点的总数  
     * @return: 记录数组,和记录总数的数组      
     */
    public void calMemberShip(){  
        memberShip = new int[DATA.length];//记录每串数据的中心点在中心点数组中的索引
        tempClusters = new int[k][DATA.length];//记录每个中心点 下点的索引号
        elementsInCenters = new int[k];//记录每个中心点的类的个数  
        for(int j=0;j<DATA.length;j++){  
            double currentDistance = Double.MAX_VALUE;//比较变量
            int currentIndex = -1;//索引位置
            double[] item = DATA[j];  
            int i;
            for(i=0;i<k;i++){//和中心点做比较
                double[] tempCentersValue = centers[i]; //中心点 
                double distance = this.manhattanDistince(item, tempCentersValue);  
                if(distance<currentDistance){  
                    currentDistance = distance;  
                    currentIndex = i; //记录当前点的中心点 
                }  
            }
            memberShip[j]=currentIndex;
            tempClusters[currentIndex][elementsInCenters[currentIndex]] = j;// 把索引号存入自己的腰包
            elementsInCenters[currentIndex]++;
        }  
    } 

}

(思路可能有相似的,毕竟算法就这么点代码。有参考一些,但本算法有一改进)

猜你喜欢

转载自blog.csdn.net/qq_27731689/article/details/78630038