用Python实现K-Means（K均值）聚类算法

以下代码仿照《机器学习实战》一书，在其基础上稍加修改：

import numpy as np
import matplotlib.pyplot as plt
def loadDataSet(filename):
    """Load a tab-separated numeric text file into a list of rows.

    Every non-blank line is split on tab characters and each field is
    converted to ``float``.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list of rows; each row is a list of floats, one per field.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be parsed as a float.
    """
    dataMat = []
    # ``with`` guarantees the handle is closed, even if parsing raises.
    with open(filename) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # failing on float('').
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

def distEclud(vecA, vecB):
    """Return the Euclidean (L2) distance between two vectors.

    Both arguments must support NumPy element-wise subtraction, e.g.
    ``np.ndarray`` or ``np.matrix`` rows of matching shape.
    """
    diff = vecA - vecB
    squared = np.power(diff, 2)
    return np.sqrt(squared.sum())

def randCent(dataSet, k):
    """Create ``k`` random initial centroids inside the data's bounding box.

    For every feature column ``j`` the centroid coordinates are drawn
    uniformly from ``[min_j, max_j)``, so each initial centroid lies within
    the range spanned by the data.

    Args:
        dataSet: 2-D array-like (np.matrix, ndarray or nested lists) of shape
            (m, n), one sample per row.
        k: Number of centroids to generate.

    Returns:
        An ``np.matrix`` of shape (k, n), one centroid per row.
    """
    data = np.asarray(dataSet)
    n = np.shape(data)[1]
    # np.asmatrix works on NumPy 1.x and 2.x; the np.mat alias was removed in 2.0.
    centroids = np.asmatrix(np.zeros((k, n)))
    for j in range(n):
        column = data[:, j]
        # ndarray reductions run in C; the builtin min/max would iterate the
        # column in Python.
        minJ = column.min()
        rangeJ = float(column.max() - minJ)
        # Keep the random centroids within [min, max) of this feature.
        centroids[:, j] = minJ + rangeJ * np.random.rand(k, 1)
    return centroids

def kMeans(dataSet, k, distMeas=distEclud, createCent=randCent):
    """Cluster the rows of ``dataSet`` into ``k`` groups with k-means.

    Alternates an assignment step (each sample goes to its nearest centroid)
    and an update step (each centroid moves to the mean of its members)
    until no sample changes cluster.

    Args:
        dataSet: np.matrix of shape (m, n), one sample per row.
        k: Number of clusters.
        distMeas: Callable ``distMeas(vecA, vecB)`` returning the distance
            between two rows. Defaults to ``distEclud``.
        createCent: Callable ``createCent(dataSet, k)`` returning the initial
            (k, n) centroid matrix. Defaults to ``randCent``.

    Returns:
        ``(centroids, clusterAssment)``: ``centroids`` is the final (k, n)
        centroid matrix; ``clusterAssment`` is an (m, 2) matrix whose column 0
        holds each sample's cluster index and column 1 the squared distance
        to that cluster's centroid.
    """
    m = np.shape(dataSet)[0]
    # np.asmatrix works on NumPy 1.x and 2.x; the np.mat alias was removed in 2.0.
    clusterAssment = np.asmatrix(np.zeros((m, 2)))
    centroids = createCent(dataSet, k)
    clusterChanged = True
    while clusterChanged:
        clusterChanged = False
        # Assignment step: attach every sample to its nearest centroid.
        for i in range(m):
            minDist = float('inf')
            minIndex = -1
            for j in range(k):
                distJI = distMeas(centroids[j, :], dataSet[i, :])
                if distJI < minDist:
                    minDist = distJI
                    minIndex = j
            if clusterAssment[i, 0] != minIndex:
                clusterChanged = True
            # Always record the distance: centroids move between passes, so the
            # stored value must be refreshed even when the index is unchanged.
            clusterAssment[i, :] = minIndex, minDist ** 2
        # Update step: move each centroid to the mean of its member samples.
        for cent in range(k):
            ptsInClust = dataSet[np.nonzero(clusterAssment[:, 0].A == cent)[0]]
            if ptsInClust.shape[0] == 0:
                # Empty cluster: keep the previous centroid. The mean of zero
                # rows is NaN and would make this centroid unusable forever.
                continue
            centroids[cent, :] = np.mean(ptsInClust, axis=0)
    return centroids, clusterAssment

if __name__ == "__main__":
    import sys

    # Optional command-line override of the data file; otherwise use the
    # original hard-coded path.
    filename = sys.argv[1] if len(sys.argv) > 1 else "G:\\IDLE\\testSet.txt"
    # np.asmatrix works on NumPy 1.x and 2.x; the np.mat alias was removed in 2.0.
    dataSet = np.asmatrix(loadDataSet(filename))
    centroids, clusterAssment = kMeans(dataSet, 4, distMeas=distEclud, createCent=randCent)
    plt.figure()
    # One scatter call for all samples (instead of one call per point),
    # coloured by the cluster index each sample was assigned to.
    plt.scatter(dataSet[:, 0].A.ravel(), dataSet[:, 1].A.ravel(), s=50,
                c=clusterAssment[:, 0].A.ravel())
    # All centroids in a single call as well.
    plt.scatter(centroids[:, 0].A.ravel(), centroids[:, 1].A.ravel(),
                c='k', s=200, marker='+')
    plt.title('K-Means')
    plt.grid()
    plt.show()

（原文此处插入了一张图片但缺少图片说明，推测为上述程序的运行结果截图。）

发布了81 篇原创文章 · 获赞 22 · 访问量 7667

猜你喜欢

转载自blog.csdn.net/qq_38883271/article/details/103935618