EM算法和GMM算法到底是个怎么回事?(十连问)

0.前言

看过很多博客“详解EM和GMM算法”这之类的,也看过李航《统计学基础》上的EM和GMM,这里总结一下几个问题,如果以下问题可以解决,我觉得EM算法和GMM模型就理解的相当好了。本文章不做数学推导,主要从大体上理解算法和模型(先弄懂算法在干什么比一味的推理公式要好很多),文末会附加pdf里边有详细的数学推导,当然你也可以参考其他博客的推导。

1.E-M在解决什么问题?(一句话概括EM)

①E-M在解决数据缺失的参数估计问题,这里缺失的参数就是所谓的隐变量。
②E-M在近似实现对观测数据的极大似然估计(Maximum Likelihood Estimation)
③当MLE过程中存在 log-sum项,虽然很好求导,但是求解起来简直怀疑人生,尤其是当数据量很大,模型比较复杂时,想死的心都有。

2.E-M是一种模型么?

①Sorry,如果你理解E-M为一种模型,那你从一开始就错了,EM只是用来解决问题的算法,你也可以理解为一种框架。
②我举个栗子,现在假如有一件事A,一般我们用方法B去解决,虽然B有时候很管用,但是在A变得比较复杂的时候,B就不行了,此时一位精神小伙走过来对B说:兄弟,让一让,我要开始装x了,那么这个小伙子就是EM(A Method That Better Than B)

3.E-step和M-step分别指的是什么?

①E-step:Expectation,就是求期望,关于谁呢,关于隐变量后验分布的期望。
②M-step:maximazation最大化期望,求出参数
③E步和M步不断迭代,直至收敛
④其实更深一步讲,后验分布一般求不出来,这就是为什么有变分推断的原因(本文不做说明,可以忽略)

4.E-M的优化目标是什么?

从EM解决了什么问题出发,很容易知道EM的优化目标是观测数据的似然函数最大化,怎么最大化,有个Q函数的概念,不做解释。

5.E-M收敛么,为什么?

答案是肯定的,这一块不做解释,因为有很多数学公式可以推导出来,具体参考李航《统计学基础》。

6.怎么理解隐变量,对任意模型都可以引入隐变量么?(隐变量的合理性)

并不是所有模型都可以引入隐变量,这里有两个条件:
①:引入隐变量后一定要使问题求解变得简单,隐变量可以理解为一个辅助变量用来帮助我们解题
②隐变量引入后,一定要确保观测数据的边缘概率(Marginal Distribution)保持不变,这样才合理。
③另外再说一点,隐变量引入后,数据可看作是由隐变量生成的,即Z生成X,我们在下一个问题中说明一下

7.如何理解GMM中的隐变量

①每一个样本都会对应一个隐变量,这个隐变量是个离散型随机变量。
②我们假设有K个Gaussian分布混合,则某样本对应的隐变量表示该样本属于某类高斯的概率,很明显有隐变量有K个取值,先验概率就等于对应的权重α
③权重α是根据样本的比例来确定的

8.怎么理解GMM和E-M之间的关系?

这个问题很关键,可以说GMM(Guassian Mixture Model)应用了E-M框架来学习参数,就像问题二所描述的一样

9.GMM可以做什么事情?

GMM可以做聚类分析:我们学习到参数后,当有一个新样本来时,我们根据参数来做一个计算,计算出该样本属于每一个高斯分布的概率,选出最大的概率,对应的分布就是该样本所归属的类,属于Soft-Cluster

10.GMM和K-means之间有什么联系?

GMM和K-mean的思想很像,不同的是,K-means是Hard—Cluster,而GMM是Soft-Cluster,要说谁更好,没有绝对可言。

参考文献

徐亦达EM算法(提取码:5roi)

源码


```python
import numpy as np
from sklearn.mixture import GaussianMixture as GM
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from collections import Counter

# 获取数据
np.random.seed(0)
data = make_blobs(n_samples=[80,60,70,140],n_features=2,centers=[[1,1],[5,5],[6,1],[2,8]],cluster_std=1,center_box=(-10,10))
X = data[0] #(290,2)
Y = data[1] #(290,)
#建模
gm = GM(n_components=4,random_state=0) #创建高斯混合模型类的一个对象
gm.fit(X) #用数据拟合模型
#预测&评估
prediction = gm.predict(X) #对训练集进行预测 (0-purple,1-blue,2-green,3-red)
accuracy =1- (6.0/len(prediction)) # 经过验证,只有6个样本被误分类
print('accuracy: ',accuracy)
  • 原始数据(Data without label)
    在这里插入图片描述
  • GMM聚类后
    在这里插入图片描述
发布了6 篇原创文章 · 获赞 8 · 访问量 413

猜你喜欢

转载自blog.csdn.net/weixin_44441131/article/details/104377623
今日推荐