【Python实例第17讲】均值偏移聚类算法

机器学习训练营——机器学习爱好者的自由交流空间(qq 群号:696721295)

均值偏移(mean shift)是一个非参数特征空间分析技术,用来寻找密度函数的最大值点。它的应用领域包括聚类分析和图像处理等。

均值偏移算法

均值偏移是一个迭代地求密度函数极值点的方法。首先,从一个初始估计 x x 出发。这里要给定一个核函数 K ( x i x ) K(x_i-x) , 典型采用的是高斯核。核函数用来确定 x x 的邻近点的权,而这些邻近点用来重新计算均值。这样,在 x x 点的密度的加权均值

m ( x ) = x i N ( x ) K ( x i x ) x i x i N ( x ) K ( x i x ) m(x)=\dfrac{\sum_{x_i\in N(x)}K(x_i-x)x_i}{\sum_{x_i\in N(x)}K(x_i-x)}

其中, N ( x ) N(x) x i x_i 的邻居集。称

m ( x ) x m(x)-x
mean shift. 现在,升级 x x 的值为 m ( x ) m(x) , 重复这个估计过程,直到 m ( x ) m(x) 收敛。
以下是一个迭代过程的示意图。
在这里插入图片描述

聚类应用

均值偏移聚类的目的是发现来自平滑密度的样本团(‘blobs’). 它是一个基于质心的算法,当质心的改变很小时,将停止搜索。因此,它能够自动设置类数,这是与k-means聚类法的显著区别。当确定所有质心后,质心对应类。对于每一个样本点,将它归于距离最近的质心代表的类里。

A demo example

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs

# #############################################################################
# Generate sample data
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

# #############################################################################
# Compute clustering with MeanShift

# The following bandwidth can be automatically detected using
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plot result
import matplotlib.pyplot as plt
from itertools import cycle

plt.figure(1)
plt.clf()

colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(X[my_members, 0], X[my_members, 1], col + '.')
    plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
             markeredgecolor='k', markersize=14)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()

number of estimated clusters : 3

在这里插入图片描述

阅读更多精彩内容,请关注微信公众号:统计学习与大数据

猜你喜欢

转载自blog.csdn.net/wong2016/article/details/84255245
今日推荐