# 1.EM算法简介

EM算法的详解和样本集实例数学过程讲解，可以详见：https://blog.csdn.net/u012421852/article/details/79915908

# 2.EM算法的Python实现

```python
# -*- coding: utf-8 -*-
"""
@author: 蔚蓝的天空Tom
Aim:实现EM算法(Expectation Maximization Algorithm)
"""

import numpy as np

class CEM(object):
    """Expectation-Maximization (EM) estimator for a two-coin mixture.

    Each sample is one round of coin tosses encoded as 1 (heads) and
    0 (tails).  Every round was tossed with either coin ``a`` or coin ``b``,
    but which coin was used is not observed.  EM alternates between
    estimating, for every round, the probability that it came from each
    coin (E-step) and re-estimating the heads probabilities ``pa`` and
    ``pb`` from those soft assignments (M-step).

    The estimation runs immediately on construction; read the outcome with
    ``GetResult()``.

    Args:
        samples: Iterable of rounds; each round is an iterable of 0/1
            outcomes.  Elements other than 0 or 1 are ignored by the counts.
        pa: Initial guess for the heads probability of coin ``a``.
        pb: Initial guess for the heads probability of coin ``b``.
        threshold: Iteration stops once the absolute change of both
            ``pa`` and ``pb`` in one M-step is below this value.
        max_iter: Upper bound on the number of EM iterations.

    Raises:
        RuntimeError: If the estimates have not converged after
            ``max_iter`` iterations.
    """

    def __init__(self, samples, pa, pb, threshold, max_iter=1000):
        self.samples = samples
        self.pa = pa
        self.pb = pb
        self.eStepRet = None
        # Never assigned by this class; kept so existing attribute access
        # keeps working.
        self.mStepRet = None
        self.threshold = threshold
        self.max_iter = max_iter

        self.work()

    @staticmethod
    def _count_heads_tails(sample):
        """Return ``(number of 1s, number of 0s)`` in one round."""
        flips = list(sample)
        return flips.count(1), flips.count(0)

    def likelihood_func(self, samples, p):
        """Return the likelihood of every round under heads probability ``p``.

        For a round with ``h`` heads and ``t`` tails the likelihood is
        ``p**h * (1 - p)**t``.

        Returns:
            A list with one likelihood value per round, in input order.
        """
        ret = []
        for sample in samples:
            heads, tails = self._count_heads_tails(sample)
            ret.append(np.power(p, heads) * np.power(1 - p, tails))
        return ret

    def e_step(self):
        """E-step: compute per-round responsibilities of coin a and coin b.

        Stores an array in ``self.eStepRet`` with one row per round; column
        0 is the probability that the round came from coin ``a`` and column
        1 the probability that it came from coin ``b`` under the current
        ``pa`` / ``pb``.
        """
        # Likelihood of each round under each coin's current parameter.
        likelihood_a = self.likelihood_func(self.samples, self.pa)
        likelihood_b = self.likelihood_func(self.samples, self.pb)

        # Normalise each pair so the two responsibilities sum to 1.
        self.eStepRet = np.array(
            [[la / (la + lb), lb / (la + lb)]
             for la, lb in zip(likelihood_a, likelihood_b)]
        )
        print('eStepRet:\n', self.eStepRet)

        return

    def m_step(self):
        """M-step: re-estimate ``pa`` and ``pb`` from the responsibilities.

        Each round contributes its heads/tails counts to both coins,
        weighted by the responsibilities computed in the last E-step.

        Returns:
            True when the absolute change of both parameters is below
            ``self.threshold`` (converged), otherwise False.
        """
        old_pa, old_pb = self.pa, self.pb
        print('old pa:', old_pa, 'old pb:', old_pb)

        h_a, t_a = 0, 0
        h_b, t_b = 0, 0
        for sample, (w_a, w_b) in zip(self.samples, self.eStepRet):
            heads, tails = self._count_heads_tails(sample)
            h_a += heads * w_a
            t_a += tails * w_a
            h_b += heads * w_b
            t_b += tails * w_b

        self.pa = h_a / (h_a + t_a)
        self.pb = h_b / (h_b + t_b)
        print('new pa:', self.pa, 'new pb:', self.pb)
        gap_pa, gap_pb = self.pa - old_pa, self.pb - old_pb
        print('gap_pa:', gap_pa, 'gap_pb:', gap_pb)

        # Compare magnitudes: a parameter that moved *down* has a negative
        # gap, which must not be mistaken for convergence.
        return bool(abs(gap_pa) < self.threshold
                    and abs(gap_pb) < self.threshold)

    def work(self):
        """Alternate E-step and M-step until convergence.

        Implemented as a loop rather than recursion so that the number of
        iterations is not limited by the interpreter's recursion depth.

        Raises:
            RuntimeError: If convergence is not reached within
                ``self.max_iter`` iterations.
        """
        for _ in range(self.max_iter):
            self.e_step()
            if self.m_step():
                print('stop em\n')
                return
        raise RuntimeError(
            'EM did not converge within %d iterations' % self.max_iter
        )

    def GetResult(self):
        """Return the current estimates as the tuple ``(pa, pb)``."""
        return self.pa, self.pb

def CEM_manual():
    """Run the two-coin EM demo on a fixed data set and print the result.

    Builds five rounds of ten tosses, starts from the initial guesses
    p(1|a) = 0.4 and p(1|b) = 0.5, runs ``CEM`` and prints the final
    ``(pa, pb)`` tuple.
    """
    # Five rounds of ten tosses each (1 = heads, 0 = tails).  Which coin
    # produced each round is treated as unknown by the algorithm.
    # Overall head frequency: (6 + 7 + 7 + 7 + 5) / 50 = 0.64.
    samples = np.array([
        [1, 0, 1, 1, 1, 0, 1, 0, 1, 0],  # 6 heads, 4 tails
        [1, 0, 1, 1, 1, 0, 1, 0, 1, 1],  # 7 heads, 3 tails
        [1, 1, 1, 1, 1, 0, 1, 0, 1, 0],  # 7 heads, 3 tails
        [1, 0, 1, 1, 1, 1, 1, 0, 1, 0],  # 7 heads, 3 tails
        [1, 0, 1, 1, 1, 0, 0, 1, 0, 0],  # 5 heads, 5 tails
    ])

    # Initial parameter guesses: p(1|a) = 0.4, p(1|b) = 0.5.
    pa, pb = 0.4, 0.5
    threshold = 0.00001
    em = CEM(samples, pa, pb, threshold)
    ret = em.GetResult()
    print(ret)

    return

# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    CEM_manual()
```

# 3.运行结果

```
runfile('C:/Users/l13277/EM.py', wdir='C:/Users/l13277')
eStepRet:
[[ 0.35215613  0.64784387]
[ 0.26599464  0.73400536]
[ 0.26599464  0.73400536]
[ 0.26599464  0.73400536]
[ 0.44914893  0.55085107]]
old pa: 0.4 old pb: 0.5
new pa: 0.621811879589 new pb: 0.648553523107
gap_pa: 0.221811879589 gap_pb: 0.148553523107
eStepRet:
[[ 0.51017251  0.48982749]
[ 0.4813223   0.5186777 ]
[ 0.4813223   0.5186777 ]
[ 0.4813223   0.5186777 ]
[ 0.53895512  0.46104488]]
old pa: 0.621811879589 old pb: 0.648553523107
new pa: 0.63630074059 new pb: 0.643678879642
gap_pa: 0.0144888610009 gap_pb: -0.00487464346478
eStepRet:
[[ 0.50320194  0.49679806]
[ 0.49519623  0.50480377]
[ 0.49519623  0.50480377]
[ 0.49519623  0.50480377]
[ 0.51120602  0.48879398]]
old pa: 0.63630074059 old pb: 0.643678879642
new pa: 0.638975359185 new pb: 0.64102463807
gap_pa: 0.00267461859504 gap_pb: -0.00265424157212
eStepRet:
[[ 0.50088945  0.49911055]
[ 0.49866584  0.50133416]
[ 0.49866584  0.50133416]
[ 0.49866584  0.50133416]
[ 0.50311303  0.49688697]]
old pa: 0.638975359185 old pb: 0.64102463807
new pa: 0.639715379851 new pb: 0.640284620153
gap_pa: 0.000740020666399 gap_pb: -0.00074001791737
eStepRet:
[[ 0.50024707  0.49975293]
[ 0.4996294   0.5003706 ]
[ 0.4996294   0.5003706 ]
[ 0.4996294   0.5003706 ]
[ 0.50086473  0.49913527]]
old pa: 0.639715379851 old pb: 0.640284620153
new pa: 0.639920938888 new pb: 0.640079061112
gap_pa: 0.000205559037088 gap_pb: -0.000205559040588
eStepRet:
[[ 0.50006863  0.49993137]
[ 0.49989706  0.50010294]
[ 0.49989706  0.50010294]
[ 0.49989706  0.50010294]
[ 0.5002402   0.4997598 ]]
old pa: 0.639920938888 old pb: 0.640079061112
new pa: 0.639978038581 new pb: 0.640021961419
gap_pa: 5.7099692811e-05 gap_pb: -5.70996928321e-05
eStepRet:
[[ 0.50001906  0.49998094]
[ 0.4999714   0.5000286 ]
[ 0.4999714   0.5000286 ]
[ 0.4999714   0.5000286 ]
[ 0.50006672  0.49993328]]
old pa: 0.639978038581 old pb: 0.640021961419
new pa: 0.639993899606 new pb: 0.640006100394
gap_pa: 1.58610249222e-05 gap_pb: -1.58610249225e-05
eStepRet:
[[ 0.5000053   0.4999947 ]
[ 0.49999206  0.50000794]
[ 0.49999206  0.50000794]
[ 0.49999206  0.50000794]
[ 0.50001853  0.49998147]]
old pa: 0.639993899606 old pb: 0.640006100394
new pa: 0.639998305446 new pb: 0.640001694554
gap_pa: 4.40584023775e-06 gap_pb: -4.40584023775e-06
stop em

(0.63999830544606295, 0.64000169455393696)
```

(end)