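Introduction to the k-means algorithm
Conditions and conventions:
- Let the set of pattern feature vectors to be classified be {x1, x2, …, xN};
- The number of classes K is fixed in advance.
Basic idea:
- First choose K cluster centers arbitrarily and assign each pattern to one of the K classes by the minimum-distance rule;
- Then repeatedly recompute the cluster centers and reassign the patterns, until the sum of squared distances from each pattern to the center of its assigned class stops decreasing (in general this is a local minimum).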
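The two bullets above amount to an assignment step followed by an update step. The following is a minimal sketch of a single such pass using NumPy broadcasting. The toy points and starting centers are made up for illustration and are not part of the exercise below.

```python
import numpy as np

# Toy data and starting centers, chosen only for illustration (assumption).
points = np.array([[0.0, 0.0], [0.0, 1.0], [4.0, 4.0], [5.0, 4.0]])
centers = np.array([[0.0, 0.0], [5.0, 4.0]])

# Assignment step: squared distance from every point to every center,
# then the index of the nearest center for each point.
sq_dist = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
labels = sq_dist.argmin(axis=1)

# Update step: each center moves to the mean of the points assigned to it.
new_centers = np.array(
    [points[labels == j].mean(axis=0) for j in range(len(centers))]
)

print(labels)
print(new_centers)
```

Here the first two points go to the first center and the last two to the second, so the centers move to (0, 0.5) and (4.5, 4). Repeating both steps until the centers stop moving gives the full algorithm.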
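Criterion function:
- The sum, over all clusters, of the squared distances from each sample in a cluster to that cluster's center.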
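Written as a formula, the criterion might look as follows. The symbols are assumptions introduced here for illustration: \omega_j is the set of samples assigned to class j, N_j is the number of samples in it, and z_j is its center (the code later in this post also uses z for the centers).

```latex
J = \sum_{j=1}^{K} \sum_{x_i \in \omega_j} \lVert x_i - z_j \rVert^2 ,
\qquad
z_j = \frac{1}{N_j} \sum_{x_i \in \omega_j} x_i
```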
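Algorithm steps
Analysis of strengths and weaknesses
The result depends on the number of cluster centers chosen and on their initial positions, as well as on the geometric properties of the pattern samples and the order in which they are read in.
In practice, different values of K and different initial cluster centers need to be tried.
Reducing the influence of the initial cluster centers
- Run the K-means algorithm many times, for example 50 to 1000 times, each time with different randomly chosen initial cluster centers.
- Compute the value of the criterion function after each run finishes.
- Take the clustering with the smallest criterion value as the final result.
- This approach is generally suitable when the number of clusters is less than 10.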
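The restart procedure in the list above can be sketched roughly as follows. This is one possible way to write it and is not the code used later in this post. The function names kmeans and best_of_restarts, the parameters max_iter, n_runs and seed, and the toy data are all assumptions made for illustration. An empty cluster simply keeps its previous center.

```python
import numpy as np


def kmeans(points, init_centers, max_iter=100):
    """Run plain k-means from the given starting centers.

    Returns (centers, labels, J), where J is the sum of squared distances
    from each point to the center of its assigned cluster.
    """
    centers = np.array(init_centers, dtype=float)
    for _ in range(max_iter):
        sq_dist = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = sq_dist.argmin(axis=1)
        new_centers = centers.copy()
        for j in range(len(centers)):
            members = points[labels == j]
            if len(members) > 0:  # an empty cluster keeps its old center
                new_centers[j] = members.mean(axis=0)
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    # Recompute the assignment and the criterion for the final centers.
    sq_dist = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    labels = sq_dist.argmin(axis=1)
    J = sq_dist[np.arange(len(points)), labels].sum()
    return centers, labels, J


def best_of_restarts(points, k, n_runs=50, seed=0):
    """Run k-means n_runs times from random data points; keep the smallest J."""
    rng = np.random.default_rng(seed)
    best = None
    for _ in range(n_runs):
        idx = rng.choice(len(points), size=k, replace=False)
        result = kmeans(points, points[idx])
        if best is None or result[2] < best[2]:
            best = result
    return best


# Toy data: two well-separated groups of three points each (assumption).
toy = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float)
centers, labels, J = best_of_restarts(toy, k=2)
print(centers)
print(J)
```

For this toy data each group contributes 4/3 to the criterion, so the best run should report J close to 8/3.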
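Example
Problem statement
Use the K-means method to perform cluster analysis on the following pattern distribution. Implement it in code; any programming language may be used.
{x1(0, 0), x2(3, 8), x3(2, 2), x4(1, 1), x5(5, 3), x6(4, 8), x7(6, 3), x8(5, 4), x9(6, 4), x10(7, 5)}
Approach
- First plot the points with matplotlib and judge how many clusters there are
- Once the number of clusters is fixed, implement the k-means algorithm in Python
- Run k-means several times with different initial cluster centers and compute the criterion function each time
- Compare the criterion values and take the run with the smallest value as the final result
- Return to the plot from the first step and confirm that the result from the previous step is correct
Solution steps
Plotting the points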
import matplotlib.pyplot as plt

# Coordinates of the ten sample points
x_values = [0, 3, 2, 1, 5, 4, 6, 5, 6, 7]
y_values = [0, 8, 2, 1, 3, 8, 3, 4, 4, 5]

plt.scatter(x_values, y_values, s=100)
plt.xlabel('x1', fontsize=20)
plt.ylabel('x2', fontsize=20)
plt.title('Coordinate distribution')
# plt.axis([0, 10, 0, 10])
plt.grid(True)  # show grid lines
plt.tick_params(axis='both', which='major', labelsize=14)
plt.show()
The plot shows three clusters, so set k = 3
Implementing the k-means algorithm
import math
import numpy as np

# Original data points
point = np.array([[0, 0],
                  [3, 8],
                  [2, 2],
                  [1, 1],
                  [5, 3],
                  [4, 8],
                  [6, 3],
                  [5, 4],
                  [6, 4],
                  [7, 5]])

# Choose k and initialize the cluster centers with the first three points
k = 3
z1 = point[0]
z2 = point[1]
z3 = point[2]

# Assign points by distance and update the centers until they stop changing
while True:
    temp1 = np.empty((0, 2))
    temp2 = np.empty((0, 2))
    temp3 = np.empty((0, 2))
    n1 = 0
    n2 = 0
    n3 = 0
    for i in range(len(point)):
        p1 = point[i] - z1
        p2 = point[i] - z2
        p3 = point[i] - z3
        d1 = math.hypot(p1[0], p1[1])
        d2 = math.hypot(p2[0], p2[1])
        d3 = math.hypot(p3[0], p3[1])
        d = min(d1, d2, d3)
        # Put the point into the cluster of the nearest center
        if d == d1:
            n1 += 1
            temp1 = np.vstack((temp1, point[i]))
        elif d == d2:
            n2 += 1
            temp2 = np.vstack((temp2, point[i]))
        else:
            n3 += 1
            temp3 = np.vstack((temp3, point[i]))
    # Stop once every center already equals the mean of its cluster
    if ((z1 == sum(temp1) / n1).all()) and ((z2 == sum(temp2) / n2).all()) and ((z3 == sum(temp3) / n3).all()):
        break
    z1 = sum(temp1) / n1
    z2 = sum(temp2) / n2
    z3 = sum(temp3) / n3

# Print each center, the shape of its cluster, and the points in it
print(z1)
print(temp1.shape)
print(temp1)
print(z2)
print(temp2.shape)
print(temp2)
print(z3)
print(temp3.shape)
print(temp3)
Result:
[1. 1.]
(3, 2)
[[0. 0.]
[2. 2.]
[1. 1.]]
[3.5 8. ]
(2, 2)
[[3. 8.]
[4. 8.]]
[5.8 3.8]
(5, 2)
[[5. 3.]
[6. 3.]
[5. 4.]
[6. 4.]
[7. 5.]]
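Analysis of the result:
There are three clusters.
The first cluster has center (1, 1) and contains three points: (0, 0), (2, 2), and (1, 1);
The second cluster has center (3.5, 8) and contains two points: (3, 8) and (4, 8);
The third cluster has center (5.8, 3.8) and contains five points: (5, 3), (6, 3), (5, 4), (6, 4), and (7, 5).
Choosing different initial cluster centers and computing the criterion function
Choose different initial cluster centers and run the algorithm 51 times (num = 0 to 50)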
import math
import random
import numpy as np
import matplotlib.pyplot as plt

# Original data points
point = np.array([[0, 0],
                  [3, 8],
                  [2, 2],
                  [1, 1],
                  [5, 3],
                  [4, 8],
                  [6, 3],
                  [5, 4],
                  [6, 4],
                  [7, 5]])

# Choose k and allocate storage for the cluster centers
k = 3
z = np.empty((k, 2))
# Run counter and the criterion value recorded for each run
num = 0
y_values = []
while num <= 50:  # 51 runs in total (num = 0 to 50)
    # Draw three distinct random indices
    r = []
    while len(r) < 3:
        x = random.randint(0, 9)
        if x not in r:
            r.append(x)
    print("num %d, the random index is %d, %d, %d" % (num, r[0], r[1], r[2]))
    # Use the randomly chosen points as the initial cluster centers
    for i in range(k):
        z[i] = point[r[i]]
    # Assign points by distance and update the centers until they stop changing
    while True:
        temp = []
        temp1 = np.empty((0, 2))
        temp2 = np.empty((0, 2))
        temp3 = np.empty((0, 2))
        n = [0] * k
        for i in range(len(point)):
            p1 = point[i] - z[0]
            p2 = point[i] - z[1]
            p3 = point[i] - z[2]
            d1 = math.hypot(p1[0], p1[1])
            d2 = math.hypot(p2[0], p2[1])
            d3 = math.hypot(p3[0], p3[1])
            d = min(d1, d2, d3)
            if d == d1:
                n[0] += 1
                temp1 = np.vstack((temp1, point[i]))
            elif d == d2:
                n[1] += 1
                temp2 = np.vstack((temp2, point[i]))
            else:
                n[2] += 1
                temp3 = np.vstack((temp3, point[i]))
        temp.append(temp1)
        temp.append(temp2)
        temp.append(temp3)
        # Stop once every center already equals the mean of its cluster
        if ((z[0] == sum(temp[0]) / n[0]).all()) and ((z[1] == sum(temp[1]) / n[1]).all()) and ((z[2] == sum(temp[2]) / n[2]).all()):
            break
        for i in range(k):
            z[i] = sum(temp[i]) / n[i]
    print("the center is")
    print(z)
    # Compute the criterion function J for this run
    loss = 0
    for i in range(k):
        for j in range(n[i]):
            p = temp[i][j] - z[i]
            d = math.pow(math.hypot(p[0], p[1]), 2)
            loss += d
    print("the J is %f" % (loss))
    y_values.append(loss)
    num += 1

# Plot the criterion value of each run
x_values = list(range(num))
plt.plot(x_values, y_values, c='red')
plt.xlabel('Run index n', fontsize=20)
plt.ylabel('Criterion function J', fontsize=20)
plt.title('Criterion function per run')
plt.grid(True)  # show grid lines
plt.tick_params(axis='both', which='major', labelsize=14)
plt.show()
Comparing the criterion values and taking the smallest as the final result
The code in the previous step prints:
num 0, the random index is 2, 5, 0
the center is
[[5.8 3.8]
[3.5 8. ]
[1. 1. ]]
the J is 10.100000
num 1, the random index is 1, 4, 0
the center is
[[3.5 8. ]
[5.8 3.8]
[1. 1. ]]
the J is 10.100000
num 2, the random index is 2, 3, 0
the center is
[[5.14285714 5. ]
[1.5 1.5 ]
[0. 0. ]]
the J is 39.857143
num 3, the random index is 0, 9, 1
the center is
[[1. 1. ]
[5.8 3.8]
[3.5 8. ]]
the J is 10.100000
num 4, the random index is 0, 7, 8
the center is
[[1. 1. ]
[3.5 8. ]
[5.8 3.8]]
the J is 10.100000
num 5, the random index is 2, 3, 8
the center is
[[2. 2. ]
[0.5 0.5 ]
[5.14285714 5. ]]
the J is 39.857143
num 6, the random index is 2, 0, 3
the center is
[[5.14285714 5. ]
[0. 0. ]
[1.5 1.5 ]]
the J is 39.857143
num 7, the random index is 4, 2, 6
the center is
[[3.5 8. ]
[1. 1. ]
[5.8 3.8]]
the J is 10.100000
num 8, the random index is 2, 8, 9
the center is
[[1. 1. ]
[5.8 3.8]
[3.5 8. ]]
the J is 10.100000
num 9, the random index is 8, 5, 0
the center is
[[5.8 3.8]
[3.5 8. ]
[1. 1. ]]
the J is 10.100000
num 10, the random index is 0, 5, 9
the center is
[[1. 1. ]
[3.5 8. ]
[5.8 3.8]]
the J is 10.100000
num 11, the random index is 0, 1, 5
the center is
[[1. 1. ]
[3.5 8. ]
[5.8 3.8]]
the J is 10.100000
num 12, the random index is 5, 1, 4
the center is
[[4. 8. ]
[3. 8. ]
[4. 2.75]]
the J is 67.500000
num 13, the random index is 2, 9, 5
the center is
[[1. 1. ]
[5.8 3.8]
[3.5 8. ]]
the J is 10.100000
num 14, the random index is 2, 4, 9
the center is
[[1. 1. ]
[5.8 3.8]
[3.5 8. ]]
the J is 10.100000
num 15, the random index is 3, 1, 0
the center is
[[5.8 3.8]
[3.5 8. ]
[1. 1. ]]
the J is 10.100000
num 16, the random index is 7, 9, 1
the center is
[[1. 1. ]
[5.8 3.8]
[3.5 8. ]]
the J is 10.100000
num 17, the random index is 3, 1, 4
the center is
[[1. 1. ]
[3.5 8. ]
[5.8 3.8]]
the J is 10.100000
num 18, the random index is 2, 7, 4
the center is
[[1. 1. ]
[3.5 8. ]
[5.8 3.8]]
the J is 10.100000
num 19, the random index is 1, 9, 8
the center is
[[3.5 8. ]
[5.8 3.8]
[1. 1. ]]
the J is 10.100000
num 20, the random index is 7, 1, 4
the center is
[[5.8 3.8]
[3.5 8. ]
[1. 1. ]]
the J is 10.100000
num 21, the random index is 6, 2, 8
the center is
[[5.8 3.8]
[1. 1. ]
[3.5 8. ]]
the J is 10.100000
num 22, the random index is 9, 5, 7
the center is
[[5.8 3.8]
[3.5 8. ]
[1. 1. ]]
the J is 10.100000
num 23, the random index is 6, 9, 0
the center is
[[5.8 3.8]
[3.5 8. ]
[1. 1. ]]
the J is 10.100000
num 24, the random index is 4, 1, 8
the center is
[[1. 1. ]
[3.5 8. ]
[5.8 3.8]]
the J is 10.100000
num 25, the random index is 2, 5, 1
the center is
[[1. 1. ]
[5.8 3.8]
[3.5 8. ]]
the J is 10.100000
num 26, the random index is 1, 6, 7
the center is
[[3.5 8. ]
[5.8 3.8]
[1. 1. ]]
the J is 10.100000
num 27, the random index is 8, 2, 1
the center is
[[5.8 3.8]
[1. 1. ]
[3.5 8. ]]
the J is 10.100000
num 28, the random index is 9, 6, 4
the center is
[[3.5 8. ]
[5.8 3.8]
[1. 1. ]]
the J is 10.100000
num 29, the random index is 9, 5, 4
the center is
[[5.8 3.8]
[3.5 8. ]
[1. 1. ]]
the J is 10.100000
num 30, the random index is 3, 7, 2
the center is
[[0.5 0.5 ]
[5.14285714 5. ]
[2. 2. ]]
the J is 39.857143
num 31, the random index is 4, 9, 5
the center is
[[1. 1. ]
[5.8 3.8]
[3.5 8. ]]
the J is 10.100000
num 32, the random index is 1, 3, 4
the center is
[[3.5 8. ]
[1. 1. ]
[5.8 3.8]]
the J is 10.100000
num 33, the random index is 5, 2, 1
the center is
[[5.8 3.8]
[1. 1. ]
[3.5 8. ]]
the J is 10.100000
num 34, the random index is 1, 3, 6
the center is
[[3.5 8. ]
[1. 1. ]
[5.8 3.8]]
the J is 10.100000
num 35, the random index is 5, 6, 0
the center is
[[3.5 8. ]
[5.8 3.8]
[1. 1. ]]
the J is 10.100000
num 36, the random index is 2, 8, 7
the center is
[[1. 1. ]
[5.8 3.8]
[3.5 8. ]]
the J is 10.100000
num 37, the random index is 2, 1, 7
the center is
[[1. 1. ]
[3.5 8. ]
[5.8 3.8]]
the J is 10.100000
num 38, the random index is 1, 0, 9
the center is
[[3.5 8. ]
[1. 1. ]
[5.8 3.8]]
the J is 10.100000
num 39, the random index is 3, 2, 5
the center is
[[1. 1. ]
[5.8 3.8]
[3.5 8. ]]
the J is 10.100000
num 40, the random index is 8, 4, 0
the center is
[[3.5 8. ]
[5.8 3.8]
[1. 1. ]]
the J is 10.100000
num 41, the random index is 8, 3, 5
the center is
[[5.8 3.8]
[1. 1. ]
[3.5 8. ]]
the J is 10.100000
num 42, the random index is 5, 0, 9
the center is
[[3.5 8. ]
[1. 1. ]
[5.8 3.8]]
the J is 10.100000
num 43, the random index is 2, 7, 0
the center is
[[1.5 1.5 ]
[5.14285714 5. ]
[0. 0. ]]
the J is 39.857143
num 44, the random index is 1, 9, 6
the center is
[[3.5 8. ]
[5.8 3.8]
[1. 1. ]]
the J is 10.100000
num 45, the random index is 2, 4, 7
the center is
[[1. 1. ]
[5.8 3.8]
[3.5 8. ]]
the J is 10.100000
num 46, the random index is 1, 4, 8
the center is
[[3.5 8. ]
[1. 1. ]
[5.8 3.8]]
the J is 10.100000
num 47, the random index is 1, 3, 5
the center is
[[3.5 8. ]
[1. 1. ]
[5.8 3.8]]
the J is 10.100000
num 48, the random index is 9, 2, 4
the center is
[[3.5 8. ]
[1. 1. ]
[5.8 3.8]]
the J is 10.100000
num 49, the random index is 6, 4, 1
the center is
[[5.8 3.8]
[1. 1. ]
[3.5 8. ]]
the J is 10.100000
num 50, the random index is 3, 7, 2
the center is
[[0.5 0.5 ]
[5.14285714 5. ]
[2. 2. ]]
the J is 39.857143
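Comparing the results above, the smallest criterion value is J = 10.1, reached with the centers (1, 1), (3.5, 8), and (5.8, 3.8). This matches the clustering found earlier, so that clustering is taken as the final result. The runs that end at J = 39.857143 or J = 67.5 converged to poorer local minima.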
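As a quick check of the value 10.1, the criterion can be recomputed directly from the three clusters reported in the result analysis. The sketch below does only that; the variable names clusters and per_cluster are introduced here for illustration.

```python
import numpy as np

# The three clusters reported in the result analysis above.
clusters = [
    np.array([[0, 0], [2, 2], [1, 1]], dtype=float),
    np.array([[3, 8], [4, 8]], dtype=float),
    np.array([[5, 3], [6, 3], [5, 4], [6, 4], [7, 5]], dtype=float),
]

# Sum of squared distances from each point to the mean of its own cluster.
per_cluster = [((c - c.mean(axis=0)) ** 2).sum() for c in clusters]
J = sum(per_cluster)

print(per_cluster)
print(J)
```

The three clusters contribute about 4, 0.5, and 5.6 respectively, which adds up to 10.1 (up to floating-point rounding).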
Plotting the points again to check the cluster centers
import matplotlib.pyplot as plt

# Coordinates of the ten sample points
x_values = [0, 3, 2, 1, 5, 4, 6, 5, 6, 7]
y_values = [0, 8, 2, 1, 3, 8, 3, 4, 4, 5]
# Coordinates of the three cluster centers found above
cen_x = [1, 3.5, 5.8]
cen_y = [1, 8, 3.8]

plt.scatter(x_values, y_values, s=100)
plt.scatter(cen_x, cen_y, edgecolors='red')  # centers drawn with red edges
plt.xlabel('x1', fontsize=20)
plt.ylabel('x2', fontsize=20)
plt.title('Coordinate distribution')
# plt.axis([0, 10, 0, 10])
plt.grid(True)  # show grid lines
plt.tick_params(axis='both', which='major', labelsize=14)
plt.show()
The cluster centers are clearly the three points marked in the figure.