【机器学习】【EM-3】EM算法(Expectation Maximization Algorithm)的Python实现

1.EM算法简介

    EM算法的详细讲解,以及基于样本集实例的数学推导过程,详见:https://blog.csdn.net/u012421852/article/details/79915908

2.EM算法的Python实现


# -*- coding: utf-8 -*-
"""
@author: 蔚蓝的天空Tom
Aim:实现EM算法(Expectation Maximization Algorithm)
"""

import numpy as np

class CEM(object):
    """EM (Expectation Maximization) estimator for the two-coin problem.

    Every row of ``samples`` is one round of tosses (1 = head, 0 = tail)
    made with one of two coins, a or b; which coin was used is not
    observed. Starting from the initial guesses ``pa`` and ``pb`` for the
    head probability of each coin, the E-step and M-step are alternated
    until both estimates move by less than ``threshold`` in one iteration.

    The algorithm runs inside the constructor; read the final estimates
    with ``GetResult()``.

    Args:
        samples: iterable of toss rounds, each an iterable of 0/1 values.
        pa: initial estimate of P(head | coin a).
        pb: initial estimate of P(head | coin b).
        threshold: stop once |delta pa| and |delta pb| are both below this.
        max_iter: upper bound on the number of EM iterations, so that a
            run that never satisfies the threshold still terminates.
    """

    def __init__(self, samples, pa, pb, threshold, max_iter=1000):
        self.samples = samples
        self.pa = pa
        self.pb = pb
        self.eStepRet = None   # per-round responsibilities, set by e_step()
        self.mStepRet = None   # kept for backward compatibility; unused here
        self.threshold = threshold
        self.max_iter = max_iter
        self.converged = False  # True once m_step() reports convergence

        self.work()

    @staticmethod
    def _count_heads_tails(sample):
        """Return (number of 1s, number of 0s) in one round of tosses."""
        tosses = list(sample)
        return tosses.count(1), tosses.count(0)

    def likelihood_func(self, samples, p):
        """Return, for each round, the likelihood p**heads * (1-p)**tails.

        Args:
            samples: iterable of toss rounds (0/1 values).
            p: head probability of the coin being evaluated.

        Returns:
            A list with one likelihood value per round.
        """
        ret = []
        for sample in samples:
            heads, tails = self._count_heads_tails(sample)
            ret.append(np.power(p, heads) * np.power(1 - p, tails))
        return ret

    def e_step(self):
        """E-step: compute the probability that each round came from a or b.

        Under the current pa/pb, the likelihood of every round is evaluated
        for both coins and normalised per round. The result is stored in
        ``self.eStepRet`` as an array with one row per round and columns
        [P(coin a | round), P(coin b | round)].
        """
        likelihood_a = self.likelihood_func(self.samples, self.pa)
        likelihood_b = self.likelihood_func(self.samples, self.pb)

        # Normalise the two likelihoods of each round so they sum to 1.
        self.eStepRet = np.array([[la / (la + lb), lb / (la + lb)]
                                  for la, lb in zip(likelihood_a, likelihood_b)])
        print('eStepRet:\n', self.eStepRet)

        return

    def m_step(self):
        """M-step: re-estimate pa and pb from the E-step responsibilities.

        The heads and tails of every round are credited to coin a and coin b
        in proportion to ``self.eStepRet``; each new estimate is the expected
        heads divided by the expected total tosses for that coin.

        Returns:
            True when both parameters changed by less than
            ``self.threshold`` in absolute value, otherwise False.
        """
        old_pa, old_pb = self.pa, self.pb
        print('old pa:', old_pa, 'old pb:', old_pb)
        h_a, t_a = 0, 0
        h_b, t_b = 0, 0
        for sample, resp in zip(self.samples, self.eStepRet):
            heads, tails = self._count_heads_tails(sample)
            h_a += heads * resp[0]
            t_a += tails * resp[0]
            h_b += heads * resp[1]
            t_b += tails * resp[1]

        self.pa = h_a / (h_a + t_a)
        self.pb = h_b / (h_b + t_b)
        print('new pa:', self.pa, 'new pb:', self.pb)
        gap_pa, gap_pb = self.pa - old_pa, self.pb - old_pb
        print('gap_pa:', gap_pa, 'gap_pb:', gap_pb)

        # Compare magnitudes: a large *negative* step is not convergence.
        return bool(abs(gap_pa) < self.threshold
                    and abs(gap_pb) < self.threshold)

    def work(self):
        """Alternate E-step and M-step until convergence or ``max_iter``.

        Implemented as a loop rather than recursion so that slowly
        converging runs cannot exhaust the interpreter's recursion limit.
        Sets ``self.converged`` to True when the threshold was met.
        """
        for _ in range(self.max_iter):
            self.e_step()
            if self.m_step():
                self.converged = True
                break
        if not self.converged:
            print('em did not converge within', self.max_iter, 'iterations')
        print('stop em\n')
        return

    def GetResult(self):
        """Return the current estimates as the tuple (pa, pb)."""
        return self.pa, self.pb
    

def CEM_manual():
    """Run the two-coin EM demo on a fixed data set and print the estimates.

    Builds five rounds of ten tosses each, starts EM from pa=0.4, pb=0.5
    with a convergence threshold of 1e-5, and prints the resulting
    (pa, pb) tuple.
    """
    # Each row is one round of 10 tosses (1 = head, 0 = tail). The trailing
    # comment gives the coin label assigned to the round and its head/tail
    # counts.
    samples = np.array([[1,0,1,1,1,0,1,0,1,0],  # coin a, 6 heads / 4 tails
                        [1,0,1,1,1,0,1,0,1,1],  # coin b, 7 heads / 3 tails
                        [1,1,1,1,1,0,1,0,1,0],  # coin a, 7 heads / 3 tails
                        [1,0,1,1,1,1,1,0,1,0],  # coin b, 7 heads / 3 tails
                        [1,0,1,1,1,0,0,1,0,0]]) # coin a, 5 heads / 5 tails
    # Head frequencies implied by the labels above:
    #   p(1|a) = (6+7+5)/30 = 0.6
    #   p(1|b) = (7+7)/20   = 0.7
    # Overall head frequency: 32/50 = 0.64

    # Initial guesses: p(1|a) = 0.4, p(1|b) = 0.5
    pa, pb = 0.4, 0.5
    threshold = 0.00001
    em = CEM(samples, pa, pb, threshold)
    ret = em.GetResult()
    print(ret)

    return

# Run the demo only when this file is executed as a script, not on import.
if __name__=='__main__':
    CEM_manual()

3.运行结果

runfile('C:/Users/l13277/EM.py', wdir='C:/Users/l13277')
eStepRet:
 [[ 0.35215613  0.64784387]
 [ 0.26599464  0.73400536]
 [ 0.26599464  0.73400536]
 [ 0.26599464  0.73400536]
 [ 0.44914893  0.55085107]]
old pa: 0.4 old pb: 0.5
new pa: 0.621811879589 new pb: 0.648553523107
gap_pa: 0.221811879589 gap_pb: 0.148553523107
eStepRet:
 [[ 0.51017251  0.48982749]
 [ 0.4813223   0.5186777 ]
 [ 0.4813223   0.5186777 ]
 [ 0.4813223   0.5186777 ]
 [ 0.53895512  0.46104488]]
old pa: 0.621811879589 old pb: 0.648553523107
new pa: 0.63630074059 new pb: 0.643678879642
gap_pa: 0.0144888610009 gap_pb: -0.00487464346478
eStepRet:
 [[ 0.50320194  0.49679806]
 [ 0.49519623  0.50480377]
 [ 0.49519623  0.50480377]
 [ 0.49519623  0.50480377]
 [ 0.51120602  0.48879398]]
old pa: 0.63630074059 old pb: 0.643678879642
new pa: 0.638975359185 new pb: 0.64102463807
gap_pa: 0.00267461859504 gap_pb: -0.00265424157212
eStepRet:
 [[ 0.50088945  0.49911055]
 [ 0.49866584  0.50133416]
 [ 0.49866584  0.50133416]
 [ 0.49866584  0.50133416]
 [ 0.50311303  0.49688697]]
old pa: 0.638975359185 old pb: 0.64102463807
new pa: 0.639715379851 new pb: 0.640284620153
gap_pa: 0.000740020666399 gap_pb: -0.00074001791737
eStepRet:
 [[ 0.50024707  0.49975293]
 [ 0.4996294   0.5003706 ]
 [ 0.4996294   0.5003706 ]
 [ 0.4996294   0.5003706 ]
 [ 0.50086473  0.49913527]]
old pa: 0.639715379851 old pb: 0.640284620153
new pa: 0.639920938888 new pb: 0.640079061112
gap_pa: 0.000205559037088 gap_pb: -0.000205559040588
eStepRet:
 [[ 0.50006863  0.49993137]
 [ 0.49989706  0.50010294]
 [ 0.49989706  0.50010294]
 [ 0.49989706  0.50010294]
 [ 0.5002402   0.4997598 ]]
old pa: 0.639920938888 old pb: 0.640079061112
new pa: 0.639978038581 new pb: 0.640021961419
gap_pa: 5.7099692811e-05 gap_pb: -5.70996928321e-05
eStepRet:
 [[ 0.50001906  0.49998094]
 [ 0.4999714   0.5000286 ]
 [ 0.4999714   0.5000286 ]
 [ 0.4999714   0.5000286 ]
 [ 0.50006672  0.49993328]]
old pa: 0.639978038581 old pb: 0.640021961419
new pa: 0.639993899606 new pb: 0.640006100394
gap_pa: 1.58610249222e-05 gap_pb: -1.58610249225e-05
eStepRet:
 [[ 0.5000053   0.4999947 ]
 [ 0.49999206  0.50000794]
 [ 0.49999206  0.50000794]
 [ 0.49999206  0.50000794]
 [ 0.50001853  0.49998147]]
old pa: 0.639993899606 old pb: 0.640006100394
new pa: 0.639998305446 new pb: 0.640001694554
gap_pa: 4.40584023775e-06 gap_pb: -4.40584023775e-06
stop em

(0.63999830544606295, 0.64000169455393696)

(end)

猜你喜欢

转载自blog.csdn.net/u012421852/article/details/79915936
0条评论
添加一条新回复