#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/3/25 14:07
# @Author : HJH
# @Site :
# @File : clauster.py
# @Software: PyCharm
"""Building blocks for k-means: sample points, clusters and a distance helper."""
import math
import random


class Clauster(object):
    """A non-empty group of samples together with their centroid.

    Attributes are plain and public:
      samples  -- the sequence of Sample objects currently in the cluster
      n_dim    -- dimensionality shared by every sample (taken from the first)
      centroid -- Sample holding the per-dimension mean of ``samples``
    """

    def __init__(self, samples):
        """Create a cluster from a non-empty sequence of Sample objects.

        Raises:
            Exception: if ``samples`` is empty or its members do not all
                have the same number of dimensions.
        """
        if len(samples) == 0:
            raise Exception('【错误】一个空的聚类')
        self.n_dim = samples[0].n_dim
        self._check_dimensions(samples)
        self.samples = samples
        self.centroid = self.cal_centroid()

    def __repr__(self):
        return str(self.samples)

    def _check_dimensions(self, samples):
        """Raise if any sample's dimensionality differs from ``self.n_dim``."""
        for sample in samples:
            if sample.n_dim != self.n_dim:
                raise Exception('【错误】聚类中样本点的个数不一致')

    def update(self, samples):
        """Replace the members, recompute the centroid, return how far it moved.

        The new members are validated *before* any state is touched, so a
        rejected call leaves ``samples`` and ``centroid`` exactly as they were.

        Returns:
            The Euclidean distance between the old and the new centroid.

        Raises:
            Exception: if ``samples`` is empty (no mean exists) or contains a
                sample whose dimensionality differs from the cluster's.
        """
        if len(samples) == 0:
            raise Exception('【错误】一个空的聚类')
        self._check_dimensions(samples)

        old_centroid = self.centroid
        self.samples = samples
        self.centroid = self.cal_centroid()
        return get_distance(old_centroid, self.centroid)

    def cal_centroid(self):
        """Return a Sample whose coordinates are the per-dimension mean."""
        n_samples = len(self.samples)
        # Transpose [[x1, y1], [x2, y2], ...] into one tuple per dimension.
        per_dimension = zip(*[sample.coords for sample in self.samples])
        # fsum keeps the mean accurate when many floats are accumulated.
        return Sample([math.fsum(values) / n_samples for values in per_dimension])


class Sample(object):
    """A point in n-dimensional space.

    Attributes:
      coords -- the sequence of coordinates as given by the caller
      n_dim  -- ``len(coords)`` captured at construction time
    """

    def __init__(self, coords):
        self.coords = coords
        self.n_dim = len(coords)

    def __repr__(self):
        return str(self.coords)


def get_distance(a, b):
    """Return the Euclidean distance between two Sample objects.

    Raises:
        Exception: if the two samples have different dimensionality.
    """
    if a.n_dim != b.n_dim:
        raise Exception('【错误】维度不同,不能计算距离')

    squared_sum = 0.0
    for coord_a, coord_b in zip(a.coords, b.coords):
        squared_sum += (coord_a - coord_b) ** 2
    return math.sqrt(squared_sum)


def gen_random_sample(n_dim, lower, upper):
    """Return a Sample with ``n_dim`` coordinates drawn uniformly from [lower, upper]."""
    return Sample([random.uniform(lower, upper) for _ in range(n_dim)])
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/3/24 21:41
# @Author : HJH
# @Site :
# @File : main.py
# @Software: PyCharm
"""K-means demo: cluster random 2-D points and show them as a scatter plot."""
import random

import matplotlib.pyplot as plt
from matplotlib import colors as mcolors

from clauster import Clauster,get_distance,gen_random_sample


def _nearest_cluster_index(sample, clusters):
    """Return the index of the cluster whose centroid is closest to ``sample``.

    Ties go to the lowest index because only a strictly smaller distance
    replaces the current best.
    """
    best_index = 0
    best_distance = get_distance(sample, clusters[0].centroid)
    for index, cluster in enumerate(clusters[1:], 1):
        distance = get_distance(sample, cluster.centroid)
        if distance < best_distance:
            best_distance = distance
            best_index = index
    return best_index


def kmeans(samples, k, threshold):
    """Partition ``samples`` into ``k`` clusters by iterative reassignment.

    ``k`` distinct samples are picked at random as the initial centroids.
    Each round assigns every sample to its nearest centroid and then moves
    each centroid to the mean of its members. Iteration stops, and a message
    with the round count is printed, once the largest centroid movement in a
    round is below ``threshold``.

    Args:
        samples: sequence of sample points to cluster.
        k: number of clusters; ``random.sample`` rejects values larger than
            ``len(samples)``.
        threshold: convergence bound on the largest centroid shift.

    Returns:
        The list of ``k`` Clauster objects after convergence.
    """
    seeds = random.sample(samples, k)
    clusters = [Clauster([seed]) for seed in seeds]

    n_loop = 0
    while True:
        n_loop += 1

        # One bucket of assigned samples per cluster, in cluster order.
        assignments = [[] for _ in clusters]
        for sample in samples:
            assignments[_nearest_cluster_index(sample, clusters)].append(sample)

        biggest_shift = 0.0
        for cluster, members in zip(clusters, assignments):
            if not members:
                # No mean can be computed from zero points. Leave the cluster
                # untouched (it keeps its previous centroid and member list,
                # so that list may be stale) and count it as not having moved.
                continue
            biggest_shift = max(biggest_shift, cluster.update(members))

        if biggest_shift < threshold:
            print('第{}次迭代后,聚类稳定'.format(n_loop))
            break
    return clusters


def run_main():
    """Cluster random points, print every assignment and plot the result."""
    n_samples = 1000
    # Number of feature dimensions per sample.
    n_feature = 2
    # Range the feature values are drawn from.
    lower = 0
    upper = 200
    # Number of clusters.
    n_cluster = 2

    samples = [gen_random_sample(n_feature, lower, upper) for _ in range(n_samples)]
    threshold = 0.2
    clusters = kmeans(samples, n_cluster, threshold)

    for i, c in enumerate(clusters):
        for sample in c.samples:
            print('聚类---{},样本点---{}'.format(i, sample))

    plt.subplot()
    color_names = list(mcolors.cnames)
    for i, c in enumerate(clusters):
        x = [sample.coords[0] for sample in c.samples]
        y = [sample.coords[1] for sample in c.samples]
        # Wrap around so a large cluster count cannot index past the palette.
        color = [color_names[i % len(color_names)]] * len(c.samples)
        plt.scatter(x, y, c=color)
    plt.show()


if __name__=='__main__':
    run_main()