theoretical reference
[Machine Learning] EM - Expectation Maximization (very detailed)
Sample introduction
There are c coins. In each round, one coin is selected at random and tossed n times; this is repeated for m rounds and every outcome is recorded.
The goal is to estimate each coin's probability of heads from the recorded tosses.
Which coin was selected in each round is unknown.
process introduction
- Randomly initialize head_p, the probability of heads for each coin
- Using head_p, compute selected_p, the probability that each coin was the one selected in each round
- Compute a new head_p for each coin from selected_p
- If head_p has converged, go to the final step (Finish); otherwise, return to the second step
- Finish
Code
import library
import random
import numpy as np
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt
Set the true probability of heads for each coin
coin_num: number of coins
coins: the true probability of heads for each coin
# Ground truth: draw each coin's probability of heads uniformly from
# {0.00, 0.01, ..., 1.00}. These values are what EM tries to recover.
coin_num = 5
coins = [random.randint(0, 100) / 100 for _ in range(coin_num)]
print(coins)
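Simulate the coin tosses
- Randomly draw a coin
- Toss it n times
- Repeat for m rounds
1 is heads, 0 is tails
[Note]: The binomial coefficient (the number of combinations) is left out when the probability is computed; it is the same for every coin, so it cancels when the probabilities are normalized. Without it, however, the raw value p^{n_1} × (1-p)^{n_2} becomes so small for large n that it loses precision and underflows to 0, which makes the optimization fail.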
n = 100   # tosses per round
m = 1000  # number of rounds

# coin_result[r, k] is the outcome of toss k in round r (1 = heads, 0 = tails).
coin_result = np.zeros((m, n), dtype=int)
# True heads probability of the coin used in each round; kept only for
# inspection, the EM code never reads it.
c_selected_record = []
for i_m in range(m):
    # Pick the coin for this round.
    coin_p = random.choice(coins)
    c_selected_record.append(coin_p)
    for i_n in range(n):
        # Toss it: heads with probability coin_p.
        coin_result[i_m, i_n] = int(random.random() < coin_p)
print(coin_result.shape, Counter(c_selected_record))
EM algorithm
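Initialization: randomly initialize each coin's probability of heads
- E step: for each round, compute the probability that each coin was the one selected, and from it the expected number of heads and tails attributed to each coin
- M step: update each coin's probability of heads from these expected counts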
To simplify the implementation, the likelihood is computed in log space; the result is unchanged:
$$p^{n_1} \times (1-p)^{n_2} = \exp\bigl(n_1 \log(p) + n_2 \log(1-p)\bigr)$$
where n_1 is the number of heads and n_2 is the number of tails in a round.
[Note]: The estimated coins do not come out in the same order as the true coins; the components are recovered only up to a permutation.
# Random starting guess for every coin's probability of heads, taken from
# {0.01, 0.02, ..., 0.99} so that no starting value is exactly 0 or 1.
# (When debugging, a fixed guess such as np.array([0.2, 0.9]) can be
# substituted for a two-coin setup.)
ini_coin_theta = np.array([random.randint(1, 99) for _ in range(coin_num)]) / 100
print('ini coin:', ini_coin_theta)
def E(coin_theta, coin_result):
    """E step: attribute each round's heads and tails to the candidate coins.

    For every round (row of ``coin_result``) the likelihood of the observed
    heads/tails counts is evaluated under each coin's current heads
    probability and normalized across coins, giving the probability that
    each coin produced that round. The binomial coefficient is omitted
    because it is identical for all coins and cancels in the normalization.

    Args:
        coin_theta: 1-D array-like with the current heads probability of
            each coin.
        coin_result: 2-D array of 0/1 outcomes, one row per round and one
            column per toss (1 = heads).

    Returns:
        A tuple ``(h_e_sum, t_e_sum, coin_selected_p)``: the expected number
        of heads and of tails attributed to each coin, summed over rounds,
        and the per-round, per-coin selection probabilities (rows sum to 1).
    """
    # Keep theta strictly inside (0, 1): log(0) = -inf would otherwise meet a
    # zero count and turn the whole row into NaN.
    eps = np.finfo(float).eps
    theta = np.clip(np.asarray(coin_theta, dtype=float), eps, 1.0 - eps)
    tosses = np.asarray(coin_result)

    h_num = tosses.sum(1)[:, None]      # heads per round, as a column
    t_num = tosses.shape[1] - h_num     # tails per round, as a column

    # Log-likelihood of each round under each coin (rounds x coins).
    log_lik = h_num * np.log(theta) + t_num * np.log(1.0 - theta)
    # Shift every row so its best coin sits at 0 before exponentiating. The
    # shift cancels in the normalization but keeps exp() from underflowing
    # to an all-zero row (0 / 0) when the number of tosses is large.
    log_lik = log_lik - log_lik.max(axis=1, keepdims=True)
    coin_selected_p = np.exp(log_lik)
    coin_selected_p = coin_selected_p / coin_selected_p.sum(axis=1, keepdims=True)

    h_e = coin_selected_p * h_num
    t_e = coin_selected_p * t_num
    return h_e.sum(0), t_e.sum(0), coin_selected_p
def M(h_e_sum, t_e_sum):
    """M step: re-estimate each coin's heads probability.

    Args:
        h_e_sum: expected number of heads attributed to each coin.
        t_e_sum: expected number of tails attributed to each coin.

    Returns:
        ``h_e_sum / (h_e_sum + t_e_sum)`` element-wise. A coin whose total
        expected count is exactly zero gets 0.5 instead of the NaN that
        0 / 0 would produce; a NaN here would spread to every coin in the
        next E step.
    """
    h_e_sum = np.asarray(h_e_sum, dtype=float)
    total = h_e_sum + np.asarray(t_e_sum, dtype=float)
    has_mass = total != 0
    # Divide by a dummy 1.0 where there is no mass so no 0 / 0 is evaluated;
    # those entries are replaced by the 0.5 fallback.
    return np.where(has_mass, h_e_sum / np.where(has_mass, total, 1.0), 0.5)
max_step = 1000
coin_result = np.array(coin_result)

# Per-iteration history, kept for the diagnostic plots.
h_e_record = []
t_e_record = []
theta_record = []
delta_record = []

coin_theta = ini_coin_theta
for i in tqdm(range(max_step)):
    # E step: expected heads/tails per coin under the current estimate.
    h_e_sum, t_e_sum, coin_selected_p = E(coin_theta, coin_result)
    h_e_record.append(h_e_sum)
    t_e_record.append(t_e_sum)

    # M step: new heads probability for every coin.
    new_coin_theta = M(h_e_sum, t_e_sum)
    theta_record.append(new_coin_theta)

    # Squared change of the estimate, used as the convergence measure.
    delta = ((new_coin_theta - coin_theta) ** 2).sum()
    delta_record.append(delta)

    # Adopt the new estimate before testing for convergence, so that
    # coin_theta matches theta_record[-1] whether or not the loop breaks.
    coin_theta = new_coin_theta
    if delta < 1e-10:
        break

h_e_record = np.array(h_e_record)
t_e_record = np.array(t_e_record)
theta_record = np.array(theta_record)
delta_record = np.array(delta_record)
print(i, coin_theta, coins)
# Example result from one run (last iteration index, estimated theta, true
# coins); the estimates appear in a different order from the true values:
# (36,
#  array([0.62988197, 0.43099465, 0.84265886, 0.99086422, 0.53815304]),
#  [0.84, 0.99, 0.63, 0.44, 0.54])
Show the parameter change process
def plot_list(f, x, y, labels, title):
    """Plot every column of ``y`` against ``x`` as a dashed line on ``f``.

    Args:
        f: object providing ``set_title`` and ``plot`` (used here with a
            matplotlib Axes).
        x: x-coordinates shared by all series.
        y: array-like of shape (len(x), n_series); a 1-D sequence is treated
            as a single series.
        labels: one legend label per series, indexed by column.
        title: title set on ``f``.
    """
    y = np.asarray(y)
    if y.ndim == 1:
        # A flat sequence is one series: make it a single column.
        y = y[:, None]
    f.set_title(title)
    for col in range(y.shape[1]):
        f.plot(x, y[:, col], label=labels[col], linestyle='--')
# One x position per EM iteration that was run (i is the last loop index).
index = range(i + 1)
# Coins are labelled by their position in the estimate vector.
labels = list(range(len(coin_theta)))
figure, axes = plt.subplots(2, 2, figsize=(12, 12), dpi=80)

# Top-left panel: change of the estimate relative to the previous step.
delta_axis = axes[0, 0]
delta_axis.set_title("delta")
delta_axis.plot(index, delta_record, label="delta")

# Remaining panels, one line per coin:
#   theta - estimated probability of heads
#   h_e   - expected (weighted) number of heads
#   t_e   - expected (weighted) number of tails
for panel, history, panel_title in (
    (axes[0, 1], theta_record, "theta"),
    (axes[1, 0], h_e_record, "h_e"),
    (axes[1, 1], t_e_record, "t_e"),
):
    plot_list(panel, index, history, labels=labels, title=panel_title)

# Add a legend to every panel.
for a in axes.flat:
    a.legend()