在学习EM时,介绍了使用EM算法求解高斯混合模型(GMM:Gaussian Mixture Model,http://blog.csdn.net/foreseerwang/article/details/75222522),从而进行聚类的过程,并与k-means算法进行了对比,可以看到GMM模型的优势。
但是,GMM模型仍存在一些问题,譬如:必须要先知道类别数K。此时,可以通过引入GMM模型参数先验的方式,自动确定K。这就需要用到本文提到的VBEM算法了,PRML第10.2节和MLaPP第21.6节均有详细论述。
此外,在实践中发现,某些初始化下,简单的K-means和GMM/EM无法获得准确的聚类结果,估计可能和数据均衡度有关。请见下图,小图1是生成的GMM模型原始数据,每个高斯过程的参数与之前文章一致,但三类数据的比例是随机产生的。可以看到,在这种情况下,三类的比例失调,有一类特别少。
小图2和小图3分别是kmeans算法和GMM/EM算法聚类的结果,可以看到,与原始数据不符。
在这种情况下,同样可以看到VBEM算法的优势。第二行的三个图是VBEM的结果。按照PRML和MLaPP书中所说,VBEM的类别数K需要选取大于实际类别数的值,这里选取K=6,小图4是直接聚6类的结果,用作VBEM的初始值。小图5是经过100次迭代的结果,仍有4类。小图6是迭代稳定后的结果。可以看到:
1. 自动生成了3类数据;
2. 3类的分布与原始数据非常一致。
这就是VBEM。世界正在变得越来越清晰...
代码如下:
% GMM demo setup: draw N samples from a K-component 2-D Gaussian mixture
% with randomly chosen mixing proportions, shuffle them, and plot the
% ground-truth clusters (figure 1, subplot 1) plus the unlabeled data
% (figure 2). The fixed seed makes the whole script reproducible.
clear all;
close all;
rng(2); % fixed seed so k-means/EM/VBEM below see the same data every run
%% Parameters
N = 1000; % total number of samples
D = 2; % data dimensionality
K = 3; % true number of mixture components
Pz = rand([K,1]); % random mixing proportions (may be very unbalanced)
Pz = Pz/sum(Pz); % normalize so the proportions sum to 1
% Component means/covariances; same values as the earlier EM clustering post
mu = [1 2; -6 2; 7 1];
sigma=zeros(K,D,D);
sigma(1,:,:)=[2 -1.5; -1.5 2];
sigma(2,:,:)=[5 -2.; -2. 3];
sigma(3,:,:)=[1 0.1; 0.1 2];
%% Data Generation and display
x = zeros(N,D);
PzCDF1 = 0; % running CDF of the mixing proportions
figure(1); subplot(2,3,1); hold on;
figure(2); hold on;
for ii = 1:K,
PzCDF2 = PzCDF1 + Pz(ii);
PzIdx1 = round(PzCDF1*N); % sample index range [PzIdx1+1, PzIdx2] for component ii
PzIdx2 = round(PzCDF2*N);
x(PzIdx1+1:PzIdx2,:) = mvnrnd(mu(ii,:), squeeze(sigma(ii,:,:)), PzIdx2-PzIdx1);
PzCDF1 = PzCDF2;
figure(1); subplot(2,3,1); hold on;
plot(x(PzIdx1+1:PzIdx2,1),x(PzIdx1+1:PzIdx2,2),'o'); % one color per true component
end;
[~, tmpidx] = sort(rand(N,1)); % random permutation of 1:N
x = x(tmpidx,:); % shuffle data
figure(1); subplot(2,3,1);
plot(mu(:,1),mu(:,2),'k*'); % mark the true component means
axis([-10,10,-4,8]);
title('1.Generated Data (original)', 'fontsize', 20);
xlabel('x1');
ylabel('x2');
figure(2);
plot(x(:,1),x(:,2),'o');
figure(2);
plot(mu(:,1),mu(:,2),'k*');
axis([-10,10,-4,8]);
title('Generated Data (original)');
xlabel('x1');
ylabel('x2');
fprintf('\n$$ Data generation and display completed...\n');
%% clustering: Matlab k-means
% Baseline: cluster the shuffled data with MATLAB's built-in k-means and
% plot each cluster with its centroid. The labels in k_idx also serve as
% the initialization for the EM section further below.
k_idx=kmeans(x,K);
figure(1); subplot(2,3,2); hold on;
for kk=1:K
    mask = (k_idx==kk);            % samples assigned to cluster kk
    plot(x(mask,1),x(mask,2),'o');
    ctr = mean(x(mask,:));         % cluster centroid
    plot(ctr(1),ctr(2),'k*');
end
axis([-10,10,-4,8]);
title('2.Clustering: Matlab k-means', 'fontsize', 20);
xlabel('x1');
ylabel('x2');
fprintf('\n$$ K-means clustering completed...\n');
%% clustering: EM
% Standard GMM/EM (maximum likelihood), run for a fixed number of
% iterations. Refer to pp.351, MLaPP.
% Pw: weight (mixing proportions)
% mu: mean of each Gaussian component
% sigma: covariance matrix of each Gaussian component
% r(i,k): responsibility; rk: sum of r over i
% px: p(x|mu,sigma)
% The k-means result above initializes the EM parameters.
Pw=zeros(K,1);
for ii=1:K,
idx=(k_idx==ii);
Pw(ii)=sum(idx)*1.0/N;
mu(ii,:)=mean(x(idx,:));
sigma(ii,:,:)=cov(x(idx,1),x(idx,2));
end;
px=zeros(N,K);
for jj=1:100, % fixed iteration count for simplicity; no convergence check
for ii=1:K,
px(:,ii)=GaussPDF(x,mu(ii,:),squeeze(sigma(ii,:,:)));
% MATLAB's built-in mvnpdf sometimes errors on a non-positive-definite
% sigma, hence the custom GaussPDF below
end;
% E step
temp=px.*repmat(Pw',N,1);
r=temp./repmat(sum(temp,2),1,K); % normalize responsibilities over k
% M step
rk=sum(r);
Pw=rk'/N;
mu=r'*x./repmat(rk',1,D);
for ii=1:K
% E[xx'] - mu*mu' form of the weighted covariance update
sigma(ii,:,:)=x'*(repmat(r(:,ii),1,D).*x)/rk(ii)-mu(ii,:)'*mu(ii,:);
end;
end;
% display: hard-assign each sample to its most likely component
[~,clst_idx]=max(px,[],2);
figure(1); subplot(2,3,3); hold on;
for ii=1:K,
idx=(clst_idx==ii);
plot(x(idx,1),x(idx,2),'o');
center = mean(x(idx,:));
sigma(ii,:,:)=cov(x(idx,1),x(idx,2));
plot(center(1),center(2),'k*');
end;
axis([-10,10,-4,8]);
title('3.Clustering: GMM/EM', 'fontsize', 20);
xlabel('x1');
ylabel('x2');
fprintf('\n$$ Gaussian Mixture using EM completed...\n');
%% Variational Bayes EM
% Variational Bayesian GMM: starts with more components than the truth
% and lets the sparsity-inducing Dirichlet prior (alp0 << 1) prune the
% extras automatically. Refer to ch.10.2, PRML.
% x: visible variable, N * D
% z: latent variable, N * K
% z: Pz, Ppi, alp0, alpk
% Pz = P(z|pi); PRML(10.37)
% Ppi = Dir(pi|alp0) PRML(10.39)
% x: Px, Pz, Ppi, mu, lambda, m0, beta0, W0, nu0
% Px = P(x|z, mu, lambda); Gaussian, PRML(10.38)
% P(mu, lambda) = P(mu|lambda)*P(lambda) PRML(10.40)
% = N(mu|m0, (beta0*lambda)^-1) * Wi(lambda|W0, nu0)
% rho: N*K, defined in PRML(10.46)
% r: N*K, responsibility; rho normalized over k, PRML(10.49)
% N_k: sum of r over n, PRML(10.51)
% xbar_k: defined in PRML(10.52)
% S_k: defined in PRML(10.53)
K = 6; % deliberately larger than the true K; VBEM prunes unused components
k_idx=kmeans(x,K); % MATLAB's built-in k-means supplies the VBEM initial values
figure(1); subplot(2,3,4); hold on;
for ii=1:K,
idx=(k_idx==ii);
plot(x(idx,1),x(idx,2),'o');
center = mean(x(idx,:));
plot(center(1),center(2),'k*');
mu(ii,:) = mean(x(idx,:));
sigma(ii,:,:)=cov(x(idx,1),x(idx,2));
px(:,ii)=GaussPDF(x,mu(ii,:),squeeze(sigma(ii,:,:)));
% mvnpdf sometimes errors on a non-positive-definite sigma, hence the
% custom GaussPDF
end;
axis([-10,10,-4,8]);
title('4.Clustering: VBEM (initial)', 'fontsize', 20);
xlabel('x1');
ylabel('x2');
% Prior hyper-parameters; definitions follow PRML Eq. (10.40)
alp0 = 0.0001; % alpha0 << 1 enables automatic component pruning
m0 = 0;
beta0 = rand()+0.5; % ad-hoc initialization
W0 = squeeze(mean(sigma));
W0inv = pinv(W0);
nu0 = D*2; % ad-hoc initialization
S_k = zeros(K,D,D);
W_k = zeros(K,D,D);
E_mu_lmbd = zeros(N,K); % left-hand side of PRML Eq. (10.64)
r = px./repmat(sum(px,2),1,K); % N*K
N_k = ones(1,K)*(-100); % sentinel so the first convergence test never trips
fprintf('\n');
for ii = 1:1000,
% M-step
N_k_new = sum(r); % 1*K, PRML Eq. (10.51)
N_k_new(N_k_new<N/1000.0)=1e-4; % guard against tiny or zero Nk
if sum(abs(N_k_new-N_k))<0.001,
break; % early stop: Nk has essentially converged
else
N_k = N_k_new;
end;
xbar_k = r'*x./repmat(N_k', 1, D); % K*D, PRML Eq. (10.52)
for jj = 1:K,
dx = x-repmat(xbar_k(jj,:), N, 1); % N*D
S_k(jj,:,:) = dx'*(dx.*repmat(r(:,jj),1,D))/N_k(jj); % D*D, PRML Eq. (10.53)
end;
alp_k = alp0 + N_k; % PRML Eq. (10.58)
beta_k = beta0 + N_k; % PRML Eq. (10.60)
m_k = (beta0*m0 + repmat(N_k',1,D).*xbar_k)./...
repmat(beta_k',1,D); % K*D, PRML Eq. (10.61)
for jj = 1:K,
dxm = xbar_k(jj,:)-m0;
Wkinv = W0inv + N_k(jj)*squeeze(S_k(jj,:,:)) + ...
dxm'*dxm*beta0*N_k(jj)/(beta0+N_k(jj));
W_k(jj,:,:) = pinv(Wkinv); % K*D*D, PRML Eq. (10.62)
end;
nu_k = nu0 + N_k; % 1*K, PRML Eq. (10.63)
% E-step: recompute the responsibilities r
alp_tilde = sum(alp_k);
E_ln_pi = psi(alp_k) - psi(alp_tilde); % PRML Eq. (10.66)
E_ln_lambda = D*log(2)*ones(1,K);
for jj = 1:D,
E_ln_lambda = E_ln_lambda + psi((nu_k+1-jj)/2);
end;
for jj = 1:K,
E_ln_lambda(jj) = E_ln_lambda(jj) + ...
log(det(squeeze(W_k(jj,:,:)))); % PRML Eq. (10.65)
dxm = x-repmat(m_k(jj,:),N,1); % N*D
Dbeta = D/beta_k(jj);
for nn = 1:N,
E_mu_lmbd(nn,jj) = Dbeta+nu_k(jj)*(dxm(nn,:)*...
squeeze(W_k(jj,:,:))*dxm(nn,:)'); % PRML Eq. (10.64)
end;
end;
rho = exp(repmat(E_ln_pi,N,1)+repmat(E_ln_lambda,N,1)/2-...
E_mu_lmbd/2); % PRML Eq. (10.46)
r = rho./repmat(sum(rho,2),1,K); % PRML Eq. (10.49)
if mod(ii,10)==0,
fprintf('%3d loops finished.\n', ii);
end;
if ii == 100, % snapshot after 100 iterations (subplot 5)
[~,clst_idx]=max(r,[],2);
figure(1); subplot(2,3,5); hold on;
for kk=1:K,
idx=(clst_idx==kk);
if sum(idx)/N>0.01, % skip clusters holding <1% of the data
plot(x(idx,1),x(idx,2),'o');
center = mean(x(idx,:));
plot(center(1),center(2),'k*');
end;
end;
axis([-10,10,-4,8]);
title('5.Clustering: VBEM (100 iter)', 'fontsize', 20);
xlabel('x1');
ylabel('x2');
end;
end;
% Final result: hard assignment from the converged responsibilities
[~,clst_idx]=max(r,[],2);
figure(1); subplot(2,3,6); hold on;
Nclst = 0;
for ii=1:K,
idx=(clst_idx==ii);
if sum(idx)/N>0.01, % count only clusters holding >1% of the data
Nclst = Nclst+1;
plot(x(idx,1),x(idx,2),'o');
center = mean(x(idx,:));
plot(center(1),center(2),'k*');
end;
end;
fprintf('\n$$ Using VBEM, totally %d clusters found.\n\n', Nclst);
axis([-10,10,-4,8]);
title('6.Clustering: VBEM (final)', 'fontsize', 20);
xlabel('x1');
ylabel('x2');
GaussPDF代码如下:
function p = GaussPDF(x, mu, sigma)
% GAUSSPDF Multivariate Gaussian density evaluated at each row of x.
%   p = GaussPDF(x, mu, sigma) returns an N*1 vector where p(n) is the
%   D-dimensional Gaussian density N(x(n,:) | mu, sigma).
%   x:     N*D sample matrix; mu: 1*D mean; sigma: D*D covariance.
%   Uses pinv, so it tolerates a (near-)singular sigma instead of
%   erroring out like mvnpdf does on non-positive-definite input.
[N, D] = size(x);
x_u = x - repmat(mu, N, 1);          % center the samples
sigma_inv = pinv(sigma);             % hoisted out of the loop: it is
                                     % loop-invariant (was recomputed N times)
norm_const = sqrt(det(sigma)*(2*pi)^D);
% Vectorized quadratic form: q(n) = x_u(n,:)*sigma_inv*x_u(n,:)'
q = sum((x_u*sigma_inv).*x_u, 2);
p = exp(-0.5*q)/norm_const;
end