The k-means algorithm: introduction and a worked problem

Introduction to k-means

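Conditions and conventions:

  • Let the set of pattern feature vectors to be classified be {x1, x2, …, xN};
  • The number of classes K is fixed in advance.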

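Basic idea:

  • First choose K cluster centers arbitrarily, then assign each pattern to one of the K classes by the minimum-distance rule;
  • Then repeatedly recompute the cluster centers and reassign the patterns until the assignment no longer changes. No step increases the sum of squared distances from the patterns to the centers of their classes, so the algorithm ends at a (possibly only local) minimum of that sum.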

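Criterion function:

  • The sum, over all samples, of the squared distance from each sample to the center of the class it belongs to.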
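Written as a formula (the symbols S_j for the set of samples assigned to class j, N_j for its size and z_j for its center are notation introduced here for illustration):

```latex
J = \sum_{j=1}^{K} \sum_{\mathbf{x} \in S_j} \lVert \mathbf{x} - \mathbf{z}_j \rVert^2,
\qquad
\mathbf{z}_j = \frac{1}{N_j} \sum_{\mathbf{x} \in S_j} \mathbf{x}
```

For a fixed assignment of samples to classes, taking each z_j as the mean of its class minimizes J, which is why k-means moves every center to the mean of its members.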

Algorithm steps
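The steps implied by the basic idea above can be sketched in Python with NumPy. This is a minimal sketch, not the author's code: the helper names `assign` and `kmeans`, the `max_iter` cap and the rule that an empty cluster keeps its previous center are assumptions added here. As in the author's programs, a distance tie goes to the lower-numbered center.

```python
import numpy as np


def assign(points, centers):
    """Hypothetical helper: index of the nearest center for every point (minimum-distance rule)."""
    # Squared Euclidean distances, shape (N, K); argmin breaks ties toward the lower index
    d2 = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1)


def kmeans(points, init_centers, max_iter=100):
    """Hypothetical helper: batch k-means starting from the given initial centers."""
    points = np.asarray(points, dtype=float)
    centers = np.asarray(init_centers, dtype=float).copy()
    for _ in range(max_iter):
        labels = assign(points, centers)
        new_centers = centers.copy()
        for j in range(len(centers)):
            members = points[labels == j]
            if len(members) > 0:  # an empty cluster keeps its previous center
                new_centers[j] = members.mean(axis=0)
        if np.allclose(new_centers, centers):  # centers no longer move: stop
            break
        centers = new_centers
    labels = assign(points, centers)
    # Criterion J: sum of squared distances from each point to its cluster center
    J = float(((points - centers[labels]) ** 2).sum())
    return centers, labels, J


points = np.array([[0, 0], [3, 8], [2, 2], [1, 1], [5, 3],
                   [4, 8], [6, 3], [5, 4], [6, 4], [7, 5]])
centers, labels, J = kmeans(points, points[:3])  # first three points as initial centers
print(centers)
print(labels)
print(J)
```

Starting from the first three points, this sketch reaches the centers (1,1), (3.5,8) and (5.8,3.8) with J = 10.1, the same clustering the worked example below arrives at.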

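Strengths and weaknesses

The result depends on the number of cluster centers chosen and their initial positions, as well as on the geometry of the pattern samples and the order in which they are read in.
In practice one has to try different values of K and different initial cluster centers.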
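The effect of the initial centers can be seen on the example data used later in this post. The sketch below is a compact re-implementation written for illustration (the name `kmeans_J` is hypothetical, not the author's code). It runs the same batch k-means from two different sets of initial centers: the first three points x1, x2, x3, and the points x3, x4, x1 (the start that appears as `num 2` in the author's output further down).

```python
import numpy as np


def kmeans_J(points, init_idx, max_iter=100):
    """Hypothetical helper: batch k-means from points[init_idx]; returns (centers, J)."""
    points = np.asarray(points, dtype=float)
    centers = points[list(init_idx)].copy()
    for _ in range(max_iter):
        # Nearest-center assignment (ties go to the lower-index center)
        labels = ((points[:, None] - centers[None]) ** 2).sum(axis=2).argmin(axis=1)
        # Move every non-empty cluster's center to the mean of its members
        new = np.array([points[labels == j].mean(axis=0) if (labels == j).any() else centers[j]
                        for j in range(len(centers))])
        if np.allclose(new, centers):
            break
        centers = new
    labels = ((points[:, None] - centers[None]) ** 2).sum(axis=2).argmin(axis=1)
    J = float(((points - centers[labels]) ** 2).sum())
    return centers, J


points = np.array([[0, 0], [3, 8], [2, 2], [1, 1], [5, 3],
                   [4, 8], [6, 3], [5, 4], [6, 4], [7, 5]])
for init in [(0, 1, 2), (2, 3, 0)]:
    centers, J = kmeans_J(points, init)
    print(init, np.round(centers, 3).tolist(), round(J, 3))
```

Both runs converge, but the second one stops with centers (36/7, 5), (1.5, 1.5) and (0, 0) and J ≈ 39.857, a worse local minimum than the J = 10.1 reached from the first start.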

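Avoiding the influence of the initial cluster centers

  • Run the K-means algorithm many times, e.g. 50 to 1000 times, each time with different, randomly chosen initial cluster centers.
  • After each run, compute the value of the criterion function.
  • Take the clustering with the smallest criterion value as the final result.
  • This approach is generally suitable when the number of clusters is below 10.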
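The restart strategy can be sketched as follows. The helper names `kmeans` and `best_of_restarts`, the default of 50 restarts and the fixed random seed are assumptions for illustration. Each restart draws K distinct data points as initial centers, and the run with the smallest criterion value J is kept.

```python
import numpy as np


def kmeans(points, centers, max_iter=100):
    """Hypothetical helper: batch k-means from the given initial centers; returns (centers, labels, J)."""
    for _ in range(max_iter):
        labels = ((points[:, None] - centers[None]) ** 2).sum(axis=2).argmin(axis=1)
        new = np.array([points[labels == j].mean(axis=0) if (labels == j).any() else centers[j]
                        for j in range(len(centers))])
        if np.allclose(new, centers):
            break
        centers = new
    labels = ((points[:, None] - centers[None]) ** 2).sum(axis=2).argmin(axis=1)
    J = float(((points - centers[labels]) ** 2).sum())
    return centers, labels, J


def best_of_restarts(points, k, n_restarts=50, seed=0):
    """Run k-means n_restarts times from random initial centers and keep the run with the smallest J."""
    points = np.asarray(points, dtype=float)
    rng = np.random.default_rng(seed)
    best = None
    for _ in range(n_restarts):
        # k distinct data points chosen at random as the initial centers
        init = points[rng.choice(len(points), size=k, replace=False)]
        result = kmeans(points, init)
        if best is None or result[2] < best[2]:
            best = result
    return best


points = np.array([[0, 0], [3, 8], [2, 2], [1, 1], [5, 3],
                   [4, 8], [6, 3], [5, 4], [6, 4], [7, 5]])
centers, labels, J = best_of_restarts(points, k=3)
print(np.round(centers, 3))
print(round(J, 3))
```

Only the best run is returned here. Printing every run's J, as the author does later, makes it easier to see how often each local minimum is reached.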

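Example

Problem description

Use the K-means method to cluster the following pattern distribution. Implement it as a program in any programming language.

{x1(0,0), x2(3,8), x3(2,2), x4(1,1), x5(5,3), x6(4,8), x7(6,3), x8(5,4), x9(6,4), x10(7,5)}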

Approach

Solution steps

Plotting the points

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mlb


mlb.rcParams['font.sans-serif']=['SimHei'] # Default font SimHei (a Chinese font); only needed if the labels are in Chinese

x_values = [0,3,2,1,5,4,6,5,6,7]
y_values = [0,8,2,1,3,8,3,4,4,5]



plt.scatter(x_values, y_values, s=100)


plt.xlabel('x1', fontsize=20)
plt.ylabel('x2', fontsize=20)
plt.title('Point distribution')
#plt.axis([0,10,0,10])
plt.grid(True)  # Show grid lines
plt.tick_params(axis='both', which='major', labelsize=14)

plt.show()

The plot suggests three clusters, so we set k = 3.

Implementing k-means

import math
import numpy as np

# Original data points
point = np.array([[0,0],
               [3,8],
               [2,2],
               [1,1],
               [5,3],
               [4,8],
               [6,3],
               [5,4],
               [6,4],
               [7,5]])

# Choose k and take the first three points as the initial cluster centers
k = 3

z1 = point[0]
z2 = point[1]
z3 = point[2]

# Assign every point to its nearest center, then move each center to the mean
# of its points; stop when the centers no longer change
while True:
    temp1 = np.empty((0, 2))
    temp2 = np.empty((0, 2))
    temp3 = np.empty((0, 2))
    for i in range(len(point)):
        p1 = point[i] - z1
        p2 = point[i] - z2
        p3 = point[i] - z3
        d1 = math.hypot(p1[0], p1[1])
        d2 = math.hypot(p2[0], p2[1])
        d3 = math.hypot(p3[0], p3[1])
        d = min(d1, d2, d3)
        # On a tie the lower-numbered cluster wins
        if d == d1:
            temp1 = np.vstack((temp1, point[i]))
        elif d == d2:
            temp2 = np.vstack((temp2, point[i]))
        else:
            temp3 = np.vstack((temp3, point[i]))
    # New centers: the mean of the points assigned to each cluster
    # (assumes no cluster ends up empty, which holds for this data and this start)
    new_z1 = temp1.mean(axis=0)
    new_z2 = temp2.mean(axis=0)
    new_z3 = temp3.mean(axis=0)
    if (z1 == new_z1).all() and (z2 == new_z2).all() and (z3 == new_z3).all():
        break
    z1, z2, z3 = new_z1, new_z2, new_z3
print(z1)
print(temp1.shape)
print(temp1)
print(z2)
print(temp2.shape)
print(temp2)
print(z3)
print(temp3.shape)
print(temp3)

Result:

[1. 1.]
(3, 2)
[[0. 0.]
 [2. 2.]
 [1. 1.]]
[3.5 8. ]
(2, 2)
[[3. 8.]
 [4. 8.]]
[5.8 3.8]
(5, 2)
[[5. 3.]
 [6. 3.]
 [5. 4.]
 [6. 4.]
 [7. 5.]]


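Analysis of the result:

There are three clusters.
The first cluster center is (1,1); this cluster contains three points: (0,0), (2,2), (1,1);
The second cluster center is (3.5,8); this cluster contains two points: (3,8), (4,8);
The third cluster center is (5.8,3.8); this cluster contains five points: (5,3), (6,3), (5,4), (6,4), (7,5).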
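As a check, the criterion value of this partition can be computed directly from the three clusters listed above. The dictionary layout below is just an illustrative choice.

```python
import numpy as np

# The three clusters found above (cluster number -> member points)
clusters = {
    1: np.array([[0, 0], [2, 2], [1, 1]]),
    2: np.array([[3, 8], [4, 8]]),
    3: np.array([[5, 3], [6, 3], [5, 4], [6, 4], [7, 5]]),
}

J = 0.0
for c, members in clusters.items():
    center = members.mean(axis=0)                # cluster center = mean of its points
    Jc = float(((members - center) ** 2).sum())  # squared distances within this cluster
    print(c, center, round(Jc, 3))
    J += Jc
print("J =", round(J, 3))
```

The three clusters contribute 4, 0.5 and 5.6, for a total of J = 10.1.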

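Choosing different initial cluster centers and computing the criterion function

Run k-means 51 times (num = 0 to 50, because the loop condition is num <= 50), each time from different randomly chosen initial cluster centers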

import math
import numpy as np
import random
import matplotlib.pyplot as plt
import matplotlib as mlb

# Original data points
point = np.array([[0,0],
               [3,8],
               [2,2],
               [1,1],
               [5,3],
               [4,8],
               [6,3],
               [5,4],
               [6,4],
               [7,5]])

# Choose k, set the run counter and allocate the array of cluster centers
k=3

z = np.empty((k,2))

# Run counter: index of the current random restart
num = 0

y_values = []

while num <=50:

    # Pick k distinct random point indices as the initial cluster centers
    r = random.sample(range(len(point)), k)
    print("num %d, the random index is %d, %d, %d" %(num, r[0], r[1], r[2]))
    # Initialise the cluster centers with the randomly chosen points
    for i in range(k):
        z[i] = point[r[i]]

    # Assign points to the nearest center and update the centers until they stop changing
    while 1:
        temp = []
        temp1 = np.empty((0, 2))
        temp2 = np.empty((0, 2))
        temp3 = np.empty((0, 2))
        n = [0] * k
        for i in range(10):
            p1 = point[i] - z[0]
            p2 = point[i] - z[1]
            p3 = point[i] - z[2]
            d1 = math.hypot(p1[0], p1[1])
            d2 = math.hypot(p2[0], p2[1])
            d3 = math.hypot(p3[0], p3[1])
            d = min(d1, d2, d3)
            if d == d1:
                n[0] += 1
                temp1 = np.vstack((temp1, point[i]))
            elif d == d2:
                n[1] += 1
                temp2 = np.vstack((temp2, point[i]))
            else:
                n[2] += 1
                temp3 = np.vstack((temp3, point[i]))
        temp.append(temp1)
        temp.append(temp2)
        temp.append(temp3)
        if ((z[0] == sum(temp[0]) / n[0]).all()) and ((z[1] == sum(temp[1]) / n[1]).all()) and ((z[2] == sum(temp[2]) / n[2]).all()):
            break
        for i in range(k):
            z[i] = sum(temp[i]) / n[i]

    print("the center is")
    print(z)
    
    # Criterion function J: sum of squared distances from the points to their cluster centers
    loss = 0
    for i in range(k):
        for j in range(n[i]):
            p = temp[i][j] - z[i]
            d = math.pow(math.hypot(p[0], p[1]), 2)
            loss += d
    print("the J is %f" %(loss))
    y_values.append(loss)
    num += 1


mlb.rcParams['font.sans-serif']=['SimHei'] # Default font SimHei (a Chinese font); only needed if the labels are in Chinese

x_values = []
for i in range(num):
    x_values.append(i)


plt.plot(x_values, y_values, c='red')

plt.xlabel('Run n', fontsize=20)
plt.ylabel('Criterion J', fontsize=20)
plt.title('Criterion function per run')
plt.grid(True)  # Show grid lines
plt.tick_params(axis='both', which='major', labelsize=14)

plt.show()

Compare the criterion values and take the result with the smallest value as the final result

The code above printed:

num 0, the random index is 2, 5, 0
the center is
[[5.8 3.8]
 [3.5 8. ]
 [1.  1. ]]
the J is 10.100000
num 1, the random index is 1, 4, 0
the center is
[[3.5 8. ]
 [5.8 3.8]
 [1.  1. ]]
the J is 10.100000
num 2, the random index is 2, 3, 0
the center is
[[5.14285714 5.        ]
 [1.5        1.5       ]
 [0.         0.        ]]
the J is 39.857143
num 3, the random index is 0, 9, 1
the center is
[[1.  1. ]
 [5.8 3.8]
 [3.5 8. ]]
the J is 10.100000
num 4, the random index is 0, 7, 8
the center is
[[1.  1. ]
 [3.5 8. ]
 [5.8 3.8]]
the J is 10.100000
num 5, the random index is 2, 3, 8
the center is
[[2.         2.        ]
 [0.5        0.5       ]
 [5.14285714 5.        ]]
the J is 39.857143
num 6, the random index is 2, 0, 3
the center is
[[5.14285714 5.        ]
 [0.         0.        ]
 [1.5        1.5       ]]
the J is 39.857143
num 7, the random index is 4, 2, 6
the center is
[[3.5 8. ]
 [1.  1. ]
 [5.8 3.8]]
the J is 10.100000
num 8, the random index is 2, 8, 9
the center is
[[1.  1. ]
 [5.8 3.8]
 [3.5 8. ]]
the J is 10.100000
num 9, the random index is 8, 5, 0
the center is
[[5.8 3.8]
 [3.5 8. ]
 [1.  1. ]]
the J is 10.100000
num 10, the random index is 0, 5, 9
the center is
[[1.  1. ]
 [3.5 8. ]
 [5.8 3.8]]
the J is 10.100000
num 11, the random index is 0, 1, 5
the center is
[[1.  1. ]
 [3.5 8. ]
 [5.8 3.8]]
the J is 10.100000
num 12, the random index is 5, 1, 4
the center is
[[4.   8.  ]
 [3.   8.  ]
 [4.   2.75]]
the J is 67.500000
num 13, the random index is 2, 9, 5
the center is
[[1.  1. ]
 [5.8 3.8]
 [3.5 8. ]]
the J is 10.100000
num 14, the random index is 2, 4, 9
the center is
[[1.  1. ]
 [5.8 3.8]
 [3.5 8. ]]
the J is 10.100000
num 15, the random index is 3, 1, 0
the center is
[[5.8 3.8]
 [3.5 8. ]
 [1.  1. ]]
the J is 10.100000
num 16, the random index is 7, 9, 1
the center is
[[1.  1. ]
 [5.8 3.8]
 [3.5 8. ]]
the J is 10.100000
num 17, the random index is 3, 1, 4
the center is
[[1.  1. ]
 [3.5 8. ]
 [5.8 3.8]]
the J is 10.100000
num 18, the random index is 2, 7, 4
the center is
[[1.  1. ]
 [3.5 8. ]
 [5.8 3.8]]
the J is 10.100000
num 19, the random index is 1, 9, 8
the center is
[[3.5 8. ]
 [5.8 3.8]
 [1.  1. ]]
the J is 10.100000
num 20, the random index is 7, 1, 4
the center is
[[5.8 3.8]
 [3.5 8. ]
 [1.  1. ]]
the J is 10.100000
num 21, the random index is 6, 2, 8
the center is
[[5.8 3.8]
 [1.  1. ]
 [3.5 8. ]]
the J is 10.100000
num 22, the random index is 9, 5, 7
the center is
[[5.8 3.8]
 [3.5 8. ]
 [1.  1. ]]
the J is 10.100000
num 23, the random index is 6, 9, 0
the center is
[[5.8 3.8]
 [3.5 8. ]
 [1.  1. ]]
the J is 10.100000
num 24, the random index is 4, 1, 8
the center is
[[1.  1. ]
 [3.5 8. ]
 [5.8 3.8]]
the J is 10.100000
num 25, the random index is 2, 5, 1
the center is
[[1.  1. ]
 [5.8 3.8]
 [3.5 8. ]]
the J is 10.100000
num 26, the random index is 1, 6, 7
the center is
[[3.5 8. ]
 [5.8 3.8]
 [1.  1. ]]
the J is 10.100000
num 27, the random index is 8, 2, 1
the center is
[[5.8 3.8]
 [1.  1. ]
 [3.5 8. ]]
the J is 10.100000
num 28, the random index is 9, 6, 4
the center is
[[3.5 8. ]
 [5.8 3.8]
 [1.  1. ]]
the J is 10.100000
num 29, the random index is 9, 5, 4
the center is
[[5.8 3.8]
 [3.5 8. ]
 [1.  1. ]]
the J is 10.100000
num 30, the random index is 3, 7, 2
the center is
[[0.5        0.5       ]
 [5.14285714 5.        ]
 [2.         2.        ]]
the J is 39.857143
num 31, the random index is 4, 9, 5
the center is
[[1.  1. ]
 [5.8 3.8]
 [3.5 8. ]]
the J is 10.100000
num 32, the random index is 1, 3, 4
the center is
[[3.5 8. ]
 [1.  1. ]
 [5.8 3.8]]
the J is 10.100000
num 33, the random index is 5, 2, 1
the center is
[[5.8 3.8]
 [1.  1. ]
 [3.5 8. ]]
the J is 10.100000
num 34, the random index is 1, 3, 6
the center is
[[3.5 8. ]
 [1.  1. ]
 [5.8 3.8]]
the J is 10.100000
num 35, the random index is 5, 6, 0
the center is
[[3.5 8. ]
 [5.8 3.8]
 [1.  1. ]]
the J is 10.100000
num 36, the random index is 2, 8, 7
the center is
[[1.  1. ]
 [5.8 3.8]
 [3.5 8. ]]
the J is 10.100000
num 37, the random index is 2, 1, 7
the center is
[[1.  1. ]
 [3.5 8. ]
 [5.8 3.8]]
the J is 10.100000
num 38, the random index is 1, 0, 9
the center is
[[3.5 8. ]
 [1.  1. ]
 [5.8 3.8]]
the J is 10.100000
num 39, the random index is 3, 2, 5
the center is
[[1.  1. ]
 [5.8 3.8]
 [3.5 8. ]]
the J is 10.100000
num 40, the random index is 8, 4, 0
the center is
[[3.5 8. ]
 [5.8 3.8]
 [1.  1. ]]
the J is 10.100000
num 41, the random index is 8, 3, 5
the center is
[[5.8 3.8]
 [1.  1. ]
 [3.5 8. ]]
the J is 10.100000
num 42, the random index is 5, 0, 9
the center is
[[3.5 8. ]
 [1.  1. ]
 [5.8 3.8]]
the J is 10.100000
num 43, the random index is 2, 7, 0
the center is
[[1.5        1.5       ]
 [5.14285714 5.        ]
 [0.         0.        ]]
the J is 39.857143
num 44, the random index is 1, 9, 6
the center is
[[3.5 8. ]
 [5.8 3.8]
 [1.  1. ]]
the J is 10.100000
num 45, the random index is 2, 4, 7
the center is
[[1.  1. ]
 [5.8 3.8]
 [3.5 8. ]]
the J is 10.100000
num 46, the random index is 1, 4, 8
the center is
[[3.5 8. ]
 [1.  1. ]
 [5.8 3.8]]
the J is 10.100000
num 47, the random index is 1, 3, 5
the center is
[[3.5 8. ]
 [1.  1. ]
 [5.8 3.8]]
the J is 10.100000
num 48, the random index is 9, 2, 4
the center is
[[3.5 8. ]
 [1.  1. ]
 [5.8 3.8]]
the J is 10.100000
num 49, the random index is 6, 4, 1
the center is
[[5.8 3.8]
 [1.  1. ]
 [3.5 8. ]]
the J is 10.100000
num 50, the random index is 3, 7, 2
the center is
[[0.5        0.5       ]
 [5.14285714 5.        ]
 [2.         2.        ]]
the J is 39.857143


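Comparing the results above, the smallest criterion value is J = 10.1. The other runs stop at J = 39.857 or J = 67.5, which are worse local minima. The run with J = 10.1 gives the same clustering as before, so that clustering is correct.

Plot the points again, marking the cluster centers, to check the result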

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mlb


mlb.rcParams['font.sans-serif']=['SimHei'] # Default font SimHei (a Chinese font); only needed if the labels are in Chinese

x_values = [0,3,2,1,5,4,6,5,6,7]
y_values = [0,8,2,1,3,8,3,4,4,5]

cen_x = [1,3.5,5.8]
cen_y = [1,8,3.8]

plt.scatter(x_values, y_values, s=100)
plt.scatter(cen_x, cen_y, c='red', marker='x', s=100)  # Cluster centers as red crosses

plt.xlabel('x1', fontsize=20)
plt.ylabel('x2', fontsize=20)
plt.title('Point distribution')
#plt.axis([0,10,0,10])
plt.grid(True)  # Show grid lines
plt.tick_params(axis='both', which='major', labelsize=14)

plt.show()

As the plot shows, the three marked points are clearly the centers of the three clusters.

Reposted from blog.csdn.net/AcSuccess/article/details/102621483