kmeans++

kmeans++是kmean算法的改进,原来kmean算法在原始数据的最小、最大区间内均等的选择K个聚类中心,然而kmeans++却是从原始数据中选择K个作为初始聚类中心,这种思路的效果优于kmean.下面是kmeas++的matlab代码,后面实例中使用的Gauss_sample函数是我的上一篇博文。

1.函数

function [K_center,seed]=kmeans_plus(data,K,LOOP)
%此代码对应kmeans++算法,该算法是对kmeans的改进,性能优于kmeans
% data是待分类的数据,尺寸为 (dim*N),dim代表每个样本的维数,N代表样本个数
% K 是设定的聚类数量,即把N个样本分成K类,每一类内部的差别小,各类之间的差别大

%首先,确定K个初始聚类中心,这些聚类中心均来源于原始的N个样本
[dim,N]=size(data);
K_center=zeros(dim,K);

seed=[];
str=round(rand(1,1)*N);
seed=[seed,data(:,str)];  %随机选择的第一个聚类中心
for k=2:1:K
    num_seed=size(seed,2);
    Dis_temp=zeros(k,N);
    d_min=zeros(1,N);
    for i=1:1:N
        d_min(1,i)=(seed(:,1)-data(:,i))'*(seed(:,1)-data(:,i));
        for j=1:1:num_seed
            Dis_temp(j,i)=(seed(:,j)-data(:,i))'*(seed(:,j)-data(:,i));
            if d_min(1,i)>Dis_temp(j,i)
                d_min(1,i)=Dis_temp(j,i);
            end
        end
    end
    pos=find(d_min==max(d_min));
    seed=[seed,data(:,pos)];
end

K_center=seed;
%所有的K个初始聚类中心已经确定,下面按照传统的kmeans算法执行
for loop=1:1:LOOP
DIS=zeros(K,N);  %计算N个样本与K个聚类之间的距离,行代表聚类,列代表样本
for i=1:1:N
    for j=1:1:K
        d=K_center(:,j)-data(:,i);
        DIS(j,i)=d'*d;
    end
end

class=zeros(1,N);
for i=1:1:N
    class(1,i)=1;
    d=DIS(1,i);
    for j=1:1:K
        if DIS(j,i)<d
            d=DIS(j,i);
            class(1,i)=j;
        end
    end
end

for k=1:1:K
    sum=zeros(dim,1);
    idx=find(class==k);
    count=length(idx);
    for c=1:1:count
        sum=sum+data(:,idx(c));
    end
K_center(:,k)=sum/count;
end

end
   


2.实例

 [data,w,m,s]=Gauss_sample(2,5,100);
figure(1);clf
plot(data(1,:),data(2,:),'b*');
 [K_center,seed]=kmeans_plus(data,5,100);


原始高斯分布的聚类中心 m=

-163   261    26  -115   157
  -196    97   -97    55  -169

经过kmeans++算法计算得到的聚类中心 K_center=

263.1232 -164.0305 -114.1521  157.6263   27.8879
   95.3258 -195.5066   55.2112 -168.8433  -95.0808

经过一一比对,效果还不错,下面用一幅图显示结果,蓝色星星代表数据,红色圆圈是kmeans++找到的聚类中心:






















猜你喜欢

转载自blog.csdn.net/cutelily2014/article/details/51814090