一起养成写作习惯!这是我参与「掘金日新计划 · 4 月更文挑战」的第2天,点击查看活动详情。
一、实验算法设计
- 读取西瓜数据集
- 随机选取k个样本作为初始聚类中心
- 计算每个样本到各个聚类中心之间的距离,将每个样本分配给距离它最近的聚类中心,此时全部样本已划分为 k 组
- 更新聚类中心,将每组中样本的均值作为该组新的聚类中心
- 重复进行第二、三步,直到聚类中心趋于稳定,或者到达最大迭代次数。
实验分析
在西瓜数据集上使用KNN分类
对西瓜数据集进行简单的分析,结果如下:
特征 | 数据特征 | 角色 |
---|---|---|
编号 | 离散 | 编号 |
密度 | 连续 | 特征 |
含糖率 | 连续 | 特征 |
好瓜 | 离散 | 标签 |
因此,选择特征密度和含糖率进行聚类分析。
二、K均值聚类分析核心代码
-
导入所需库
import matplotlib.pyplot as plt import numpy as np import pandas as pd 复制代码
在本次实验中,我选择pandas1作为读取数据集的主要工具,选择numpy2加速主要的数学运算,选择matplotlib3进行数据可视化分析。
-
定义K均值聚类 KMeans
-
定义
__init__()
以初始化分类器class KNNClassifier: def __init__(self, x: pd.DataFrame): self.x = x ... 复制代码
其中
X
代表数据集特征。 -
预定义距离函数
distanceAll()
def distanceAll(center, rest): distances = np.apply_along_axis(_distances, 1, rest, center) return distances.sum() def _distances(point: np.ndarray, centers: np.ndarray): distances = np.apply_along_axis(_distance, 1, centers, point) return distances def _distance(x, y): return np.sqrt(np.dot(x, x) - 2 * np.dot(x, y) + np.dot(y, y)) 复制代码
此处我进行了多处优化,具体优化点如下:
避免使用
for-loop
以加快运行速度在第一个函数
distanceAll
中,传入的center
和rest
为多维矩阵,此处实现了center
和rest
之间互相两两求距离函数,且未使用任何for
循环,极大提升了运行速度。复用
_distance(x, y)
计算结果欧氏距离一般计算公式为:
但是我在此处使用的公式为其展开形式
此公式中红色部分在计算欧氏距离时会多次使用,因此,使用此公式可以充分利用numpy的缓存机制,减少不必要的重复运算量。
-
预定义
allocate()
核心方法为每个点找到最近的聚类中心def allocateAll(center, rest): # 2. 计算每个样本到各个聚类中心之间的距离,将每个样本分配给距离它最近的聚类中心 allocates = np.apply_along_axis(_allocate, 1, rest, center) # sns.scatterplot(data=rest, x=0, y=1, hue=allocates) copied = rest.copy() copied["allocations"] = allocates groups = copied.groupby("allocations").groups # 绘图 fig = plt.figure() ax = rest.plot.scatter(x=0, y=1, c=allocates, colormap='viridis', legend=True) center.iloc[list(groups.keys())].plot.scatter(x=0, y=1, c=list(groups.keys()), marker="x", colormap='viridis', s=200, ax=ax) plt.show() return groups def _allocate(point: np.ndarray, centers: np.ndarray): distances = np.apply_along_axis(_distance, 1, centers, point, "euclidean") nearest_center = np.argmin(distances) return nearest_center 复制代码
同时,在对每个点寻找中心进行聚类的过程中,还集合了绘图可视化方法。此处的可视化方法将绘制出之后聚类的过程。
-
定义
train()
在训练集上进行迭代训练class KMeans: ... def train(self, k): print(f" === k = {k} === ") batch = self.x.shape[0] features = self.x.shape[1] # 1. 随机选取 k 个样本作为初始的聚类中心 index = np.random.randint(0, batch, size=k) centers: pd.DataFrame = self.x.iloc[index] # 聚类中心 # rest: pd.DataFrame = self.x.loc[~self.x.index.isin(index)] allocations = allocateAll(centers, self.x) for i in range(10): last_centers = centers centers = np.empty((k, 2)) for label, points in allocations.items(): center = self.x.iloc[points] new_center = np.average(center, axis=0) centers[label] = new_center if np.isclose(last_centers, centers).all(): print(f"k = {k} 收敛,停止!") return distanceAll(pd.DataFrame(centers), self.x) allocations = allocateAll(pd.DataFrame(centers), self.x) 复制代码
在本段代码中,我指定每次训练最多进行
10
轮,一般来说,只需要迭代5次即可收敛到聚类中心。代码分为两部分,第一次的聚类中心在样本中随机选取,进行第一次聚类之后,再依据上一次的聚类结果,选择每一类的均值点作为中心进行循环迭代,当下一轮迭代的循环中心与上一轮相差不大时,终止迭代,返回此时的wss距离值。
-
三、实验数据及结果分析
在西瓜数据集上使用K均值聚类
-
导入所需库
import matplotlib.pyplot as plt import pandas as pd from model import KMeans 复制代码
此处导入刚刚编写的
KMeans
以及绘图工具matplotlib
进行wss曲线的绘制。 -
读取数据集并构建模型
df = pd.read_csv("kmeansdata.csv") model = KMeans(df[["m", "h"]]) 复制代码
此处读入西瓜数据集,并选定特征
m
和h
构建模型。 -
KMeans 模型训练,可视化,WSS曲线分析
wss = [] for i in range(2, 10): wss.append(model.train(k=i)) plt.plot(range(2, 10), wss) plt.savefig("result.png") 复制代码
此处我在2到15中选择
k
值,分别使用这些k
值在KMeans模型上进行训练,并保存每一次训练之后返回的wss距离,最后对wss距离进行可视化分析。训练过程可视化
k=3
首先,在数据集中随机选取三个样本作为聚类中心:
可以看出,选择的聚类中心偏下,然后进行第一次迭代:
在每一类中,选择其中心点作为下一次聚类中心,然后对每个点重新决定其类别,并进行下一次迭代:
可以看出,此时中心往中间偏移,分类更加合理。再进行一次迭代:
此后迭代中心不再产生明显变化,代表聚类中心收敛,本轮聚类结束。
WSS曲线可视化
四、总结及心得体会
- 在简单的数据集(如西瓜数据集)上,聚类效果较好,在几次迭代内便可达到收敛。
- 根据对不同k值的可视化分析,可以发现,在k=3时达到"肘部",此时K为最优值,大于3的k值会因为类别过多而失去统计意义,k值太小会导致类别过少,使类内距离急剧上升。
- 使用C接口实现Python程序比使用
Python-based-coding
效率更高。 - 掌握了一些简单的数据可视化方法,学会使用一些简单的matplotlib库中有关pyplot的函数,利用简单的数据可视化方法将大量的数据转化成图片,极大地简化了我们对结果数据的分析和比对,能够更轻易的获得一些结果上的规律和结论。
五、对本实验过程及方法、手段的改进建议
- 数据集可视化时,对高维特征粗暴选取前两个维度进行可视化分析会丢失其他维度的特征信息,此处可以选择降维方法,例如PCA4等,把高维特征投影到二维平面上以进行可视化分析。
- 可以尝试更加复杂的数据集。
- 可以尝试考量更多距离函数。