机器学习-二分KMeans

机器学习-二分KMeans

由于传统的KMeans算法的聚类结果容易受到初始聚类中心点选择的影响,因此在传统的KMeans算法的基础上进行算法改进,对初始中心点选取比较严格,各中心点的距离较远,这就避免了初始聚类中心会选到一个类上,一定程度上克服了算法限入局部最优状态。

二分KMeans(Bisecting KMeans)算法的主要思想是:首先将所有点作为一个簇,然后将该簇一分为二。之后选择能最大限度降低聚类代价函数(误差平方和)的簇划分为两个簇。以此进行下去,直到簇的数目等于用户给定的数目k为止。以上隐含的一个原则是:因为聚类的误差平方和能够衡量聚类性能,该值越小表示数据点越接近于它们的质心,聚类效果就越好。所以我们需要对误差平方和最大的簇进行再一次划分,因为误差平方和越大,表示该簇聚类效果越不好,越有可能是多个簇被当成了一个簇,所以我们首先需要对这个簇进行划分。

代码实现(其中的kmeans方法参照上一篇)

# 二分kmeans实现

from numpy import *
import matplotlib.pyplot as plt

# 数据文件转矩阵
from sklearn.cluster import KMeans

from com.machineLearning.kmeans.KmeansSelf import kMeans


def file2matrix(filePath):
    dataSet = []
    fr = open(filePath, 'r')
    content = fr.read()
    for line in content.splitlines():
        dataSet.append([line.split('\t')[0], line.split('\t')[1]])
    fr.close()
    return dataSet

# 欧氏距离公式
def distEclud(vecA, vecB):
    return linalg.norm(vecA.astype(float) - vecB.astype(float))

# 根据聚类中心绘制散点图,以及绘制聚类中心
def color_cluster(dataindx, dataSet, plt, k=4):
    # print(dataindx)
    plt.scatter(dataSet[:, 0].tolist(), dataSet[:, 1].tolist(), c='blue', marker='o')
    # index = 0
    # datalen = len(dataindx)
    # for indx in range(datalen):
    #     if int(dataindx[index]) == 0:
    #         plt.scatter(dataSet[index, 0].tolist(),dataSet[index, 1].tolist(), c='blue', marker='o')
    #     elif int(dataindx[index]) == 1:
    #         plt.scatter(dataSet[index, 0].tolist(), dataSet[index, 1].tolist(), c='green', marker='o')
    #     elif int(dataindx[index]) == 2:
    #         plt.scatter(dataSet[index, 0].tolist(), dataSet[index, 1].tolist(), c='red', marker='o')
    #     elif int(dataindx[index]) == 3:
    #         plt.scatter(dataSet[index, 0].tolist(), dataSet[index, 1].tolist(), c='cyan', marker='o')
    #     index += 1

# 绘制散点图
def drawScatter(plt, mydata, size=20, color='blue', mrkr='o'):
    # print(mydata.T[0][0].tolist())
    # print(mydata.T[1][0].tolist())
    plt.scatter(mydata.T[0][0].tolist(), mydata.T[1][0].tolist(), s=size, c=color, marker=mrkr)

dataMat = file2matrix("/Users/FengZhen/Desktop/accumulate/机器学习/推荐系统/kmeans聚类测试集.txt")  # 从文件构建的数据集
dataSet = mat(dataMat).astype(float)   # 转换为矩阵形式
k = 4   # 分类数
m = dataSet.shape[0]    # 行数
# axis 不设置值,对 m*n 个数求均值,返回一个实数
# axis = 0:压缩行,对各列求均值,返回 1* n 矩阵
# axis =1 :压缩列,对各行求均值,返回 m *1 矩阵
centroid0 = mean(dataSet, axis=0)[0]   # 初始化第一个聚类中心:每一列的均值
centList = [centroid0]  # 把均值聚类中心加入中心表中
ClustDist = mat(zeros((m, 2)))  # 初始化聚类距离表,距离方差,每组数据到中心点的距离
for j in range(m):
    ClustDist[j, 1] = distEclud(centroid0, dataSet[j, :])**2


while(len(centList) < k):
    lowestSSE = inf    # 初始化最小误差平方和。核心参数,这个值越小说明聚类的效果越好
    for i in range(len(centList)):
        print(i)
        ptsInCurrCluster = dataSet[nonzero(ClustDist[:, 0].A == i)[0], :]   # .A为转换数组
        # centroidMat:中心点, splitClustAss:所属分类,距离
        centroidMat, splitClustAss = kMeans(ptsInCurrCluster, 2)
        # 计算所有距离和
        sseSplit = sum(splitClustAss[:, 1])     # 计算splitClustAss的距离平方和
        sseNotSplit = sum(ClustDist[nonzero(ClustDist[:, 0].A != i)[0], 1]) # 计算ClustDist第一列不等于i的距离平方和
        if(sseSplit + sseNotSplit) < lowestSSE:
            bestCentToSplit = i     # 确定聚类中心的最优分隔点
            bestNewCents = centroidMat  # 用新的聚类中心更新最优聚类中心
            bestClustAss = splitClustAss.copy() # 用深拷贝聚类距离表为最优聚类距离表
            lowestSSE = sseSplit + sseNotSplit  # 更新lowestSSE
    # bestClustAss 赋值为聚类中心的索引
    # 第一部分:bestClustAss[bIndx0,0]赋值为聚类中心的索引
    bestClustAss[nonzero(bestClustAss[:, 0].A == 1)[0], 0] = len(centList)
    # 用最优分隔点的指定聚类中心索引
    bestClustAss[nonzero(bestClustAss[:, 0].A == 0)[0], 0] = bestCentToSplit

    # 覆盖:bestNewCents[0, :].tolist()[0]附加到原有聚类中心的bestCentToSplit位置
    centList[bestCentToSplit] = bestNewCents[0, :].tolist()[0]
    # 增加:聚类中心增加一个新的bestNewCents[1, :].tolist()[0]向量
    centList.append(bestNewCents[1, :].tolist()[0])
    ClustDist[nonzero(ClustDist[:, 0].A == bestCentToSplit)[0], :] = bestClustAss

color_cluster(ClustDist[:, 0:1], dataSet, plt)
# print("centList:", mat(centList))
drawScatter(plt, mat(centList), size=60, color='red', mrkr='D')
plt.show()

效果如下

猜你喜欢

转载自www.cnblogs.com/EnzoDin/p/12528003.html